Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
2d8f925b
Unverified
Commit
2d8f925b
authored
Apr 12, 2022
by
Yuge Zhang
Committed by
GitHub
Apr 12, 2022
Browse files
Bug fix of Retiarii hyperparameter mutation (#4751)
parent
c22dc0fc
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
92 additions
and
20 deletions
+92
-20
nni/common/serializer.py
nni/common/serializer.py
+32
-6
nni/retiarii/execution/api.py
nni/retiarii/execution/api.py
+3
-1
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+12
-12
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+23
-1
test/ut/sdk/test_serializer.py
test/ut/sdk/test_serializer.py
+22
-0
No files found.
nni/common/serializer.py
View file @
2d8f925b
...
...
@@ -66,6 +66,12 @@ class Traceable:
"""
raise
NotImplementedError
()
def
get
(
self
)
->
Any
:
"""
Get the original object. Usually used together with ``trace_copy``.
"""
raise
NotImplementedError
()
class
Translatable
(
abc
.
ABC
):
"""
...
...
@@ -136,6 +142,13 @@ class SerializableObject(Traceable):
{
k
:
copy
.
copy
(
v
)
for
k
,
v
in
self
.
trace_kwargs
.
items
()},
)
def
get
(
self
)
->
T
:
if
not
self
.
_get_nni_attr
(
'call_super'
):
# Reinitialize
return
trace
(
self
.
trace_symbol
)(
*
self
.
trace_args
,
**
self
.
trace_kwargs
)
return
self
@
property
def
trace_symbol
(
self
)
->
Any
:
return
self
.
_get_nni_attr
(
'symbol'
)
...
...
@@ -202,11 +215,15 @@ def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T:
{
k
:
copy
.
copy
(
v
)
for
k
,
v
in
self
.
trace_kwargs
.
items
()},
)
def
get
(
self
):
return
self
attributes
=
{
'trace_symbol'
:
property
(
getter_factory
(
'symbol'
),
setter_factory
(
'symbol'
)),
'trace_args'
:
property
(
getter_factory
(
'args'
),
setter_factory
(
'args'
)),
'trace_kwargs'
:
property
(
getter_factory
(
'kwargs'
),
setter_factory
(
'kwargs'
)),
'trace_copy'
:
trace_copy
'trace_copy'
:
trace_copy
,
'get'
:
get
,
}
if
not
create_wrapper
:
...
...
@@ -562,13 +579,13 @@ class _pickling_object:
# Used in `_trace_cls`.
def
__new__
(
cls
,
type_
,
kw_only
,
data
):
type_
=
cloudpickle
.
loads
(
type_
)
type_
=
_wrapped_
cloudpickle
_
loads
(
type_
)
# Restore the trace type
type_
=
_trace_cls
(
type_
,
kw_only
)
# restore type
if
'_nni_symbol'
in
data
:
data
[
'_nni_symbol'
]
=
cloudpickle
.
loads
(
data
[
'_nni_symbol'
])
data
[
'_nni_symbol'
]
=
_wrapped_
cloudpickle
_
loads
(
data
[
'_nni_symbol'
])
# https://docs.python.org/3/library/pickle.html#pickling-class-instances
obj
=
type_
.
__new__
(
type_
)
...
...
@@ -674,7 +691,7 @@ def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
def
_is_function
(
obj
:
Any
)
->
bool
:
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return
isinstance
(
obj
,
(
types
.
FunctionType
,
types
.
BuiltinFunctionType
,
types
.
MethodType
,
types
.
BuiltinMethodType
))
types
.
BuiltinMethodType
))
and
obj
is
not
None
def
_import_cls_or_func_from_name
(
target
:
str
)
->
Any
:
...
...
@@ -727,7 +744,7 @@ def get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096)
def
import_cls_or_func_from_hybrid_name
(
s
:
str
)
->
Any
:
if
s
.
startswith
(
'bytes:'
):
b
=
base64
.
b64decode
(
s
.
split
(
':'
,
1
)[
-
1
])
return
cloudpickle
.
loads
(
b
)
return
_wrapped_
cloudpickle
_
loads
(
b
)
if
s
.
startswith
(
'path:'
):
s
=
s
.
split
(
':'
,
1
)[
-
1
]
return
_import_cls_or_func_from_name
(
s
)
...
...
@@ -800,5 +817,14 @@ def _json_tricks_any_object_decode(obj: Dict[str, Any]) -> Any:
if
isinstance
(
obj
,
dict
)
and
'__nni_obj__'
in
obj
:
obj
=
obj
[
'__nni_obj__'
]
b
=
base64
.
b64decode
(
obj
)
return
cloudpickle
.
loads
(
b
)
return
_wrapped_
cloudpickle
_
loads
(
b
)
return
obj
def
_wrapped_cloudpickle_loads
(
b
:
bytes
)
->
Any
:
try
:
return
cloudpickle
.
loads
(
b
)
except
TypeError
:
warnings
.
warn
(
'TypeError encountered during deserializing object. This could be caused by '
'inconsistency between Python versions where dump and load happens.'
)
raise
nni/retiarii/execution/api.py
View file @
2d8f925b
...
...
@@ -21,7 +21,9 @@ def set_execution_engine(engine: AbstractExecutionEngine) -> None:
if
_execution_engine
is
None
:
_execution_engine
=
engine
else
:
raise
RuntimeError
(
'Execution engine is already set.'
)
raise
RuntimeError
(
'Execution engine is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.'
)
def
get_execution_engine
()
->
AbstractExecutionEngine
:
...
...
nni/retiarii/nn/pytorch/mutator.py
View file @
2d8f925b
...
...
@@ -364,27 +364,27 @@ class EvaluatorValueChoiceMutator(Mutator):
if
not
is_traceable
(
obj
):
return
obj
if
not
any
(
isinstance
(
value
,
ValueChoiceX
)
for
value
in
obj
.
trace_kwargs
.
values
()):
# No valuechoice, not interesting
return
obj
# Make a copy
obj
=
obj
.
trace_copy
()
result
=
{}
updates
=
{}
# For each argument that is a composition of value choice
# we find all the leaf-value-choice in the mutation
# and compute the final
result
# and compute the final
updates
for
key
,
param
in
obj
.
trace_kwargs
.
items
():
if
isinstance
(
param
,
ValueChoiceX
):
leaf_node_values
=
[
value_choice_decisions
[
choice
.
label
]
for
choice
in
param
.
inner_choices
()]
result
[
key
]
=
param
.
evaluate
(
leaf_node_values
)
updates
[
key
]
=
param
.
evaluate
(
leaf_node_values
)
elif
is_traceable
(
param
):
# Recursively
result
[
key
]
=
self
.
_mutate_traceable_object
(
param
,
value_choice_decisions
)
sub_update
=
self
.
_mutate_traceable_object
(
param
,
value_choice_decisions
)
if
sub_update
is
not
param
:
# if mutated
updates
[
key
]
=
sub_update
if
updates
:
mutated_obj
=
obj
.
trace_copy
()
# Make a copy
mutated_obj
.
trace_kwargs
.
update
(
updates
)
# Mutate
mutated_obj
=
mutated_obj
.
get
()
# Instantiate the full mutated object
obj
.
trace_kwargs
.
update
(
result
)
return
mutated_obj
return
obj
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
2d8f925b
...
...
@@ -6,7 +6,9 @@ from collections import Counter
import
pytest
import
nni
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
import
nni.retiarii.nn.pytorch
as
nn
import
pytorch_lightning
import
torch
import
torch.nn.functional
as
F
from
nni.retiarii
import
InvalidMutation
,
Sampler
,
basic_unit
...
...
@@ -1202,11 +1204,31 @@ class Shared(unittest.TestCase):
samplers
=
[
RandomSampler
()
for
_
in
range
(
3
)]
for
_
in
range
(
10
):
model
=
_apply_all_mutators
(
init_model
,
mutators
,
samplers
)
a
,
v
=
model
.
evaluator
.
trace_kwargs
[
't'
].
trace_kwargs
[
'a'
]
,
model
.
evaluator
.
trace_kwargs
[
'v'
]
a
,
v
=
model
.
evaluator
.
trace_kwargs
[
't'
].
a
,
model
.
evaluator
.
trace_kwargs
[
'v'
]
assert
v
%
10
==
a
assert
a
in
[
1
,
2
,
3
]
assert
v
//
10
in
[
1
,
2
,
3
]
@
unittest
.
skipIf
(
pytorch_lightning
.
__version__
<
'1.0'
,
'Legacy PyTorch-lightning not supported'
)
def
test_valuechoice_lightning
(
self
):
@
nni
.
trace
class
AnyModule
(
pl
.
LightningModule
):
pass
evaluator
=
pl
.
Lightning
(
AnyModule
(),
pl
.
Trainer
(
max_epochs
=
nn
.
ValueChoice
([
1
,
2
,
3
])))
mutators
=
process_evaluator_mutations
(
evaluator
,
[])
assert
len
(
mutators
)
==
2
init_model
=
Model
(
_internal
=
True
)
init_model
.
evaluator
=
evaluator
samplers
=
[
RandomSampler
()
for
_
in
range
(
2
)]
values
=
[]
for
_
in
range
(
20
):
model
=
_apply_all_mutators
(
init_model
,
mutators
,
samplers
)
values
.
append
(
model
.
evaluator
.
trainer
.
max_epochs
)
model
.
_dump
()
assert
len
(
set
(
values
))
==
3
def
test_retiarii_nn_import
(
self
):
dummy
=
torch
.
zeros
(
1
,
16
,
32
,
24
)
nn
.
init
.
uniform_
(
dummy
)
...
...
test/ut/sdk/test_serializer.py
View file @
2d8f925b
...
...
@@ -353,3 +353,25 @@ def test_subclass():
assert
obj
.
trace_kwargs
==
{
'c'
:
1
,
'd'
:
2
}
assert
issubclass
(
type
(
obj
),
Super
)
assert
isinstance
(
obj
,
Super
)
def
test_get
():
@
nni
.
trace
class
Foo
:
def
__init__
(
self
,
a
=
1
):
self
.
_a
=
a
def
bar
(
self
):
return
self
.
_a
+
1
obj
=
Foo
(
3
)
assert
nni
.
load
(
nni
.
dump
(
obj
)).
bar
()
==
4
obj1
=
obj
.
trace_copy
()
with
pytest
.
raises
(
AttributeError
):
obj1
.
bar
()
obj1
.
trace_kwargs
[
'a'
]
=
5
obj1
=
obj1
.
get
()
assert
obj1
.
bar
()
==
6
obj2
=
obj1
.
trace_copy
()
obj2
.
trace_kwargs
[
'a'
]
=
-
1
assert
obj2
.
get
().
bar
()
==
0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment