Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
443ba8c1
Unverified
Commit
443ba8c1
authored
Dec 06, 2021
by
Yuge Zhang
Committed by
GitHub
Dec 06, 2021
Browse files
Serialization infrastructure V2 (#4337)
parent
896c516f
Changes
40
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
367 additions
and
148 deletions
+367
-148
docs/en_US/NAS/ApiReference.rst
docs/en_US/NAS/ApiReference.rst
+3
-1
docs/en_US/autotune_ref.rst
docs/en_US/autotune_ref.rst
+6
-0
examples/nas/multi-trial/nasbench101/network.py
examples/nas/multi-trial/nasbench101/network.py
+2
-2
examples/nas/multi-trial/nasbench201/network.py
examples/nas/multi-trial/nasbench201/network.py
+2
-2
nni/__init__.py
nni/__init__.py
+1
-1
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
+8
-8
nni/algorithms/hpo/hyperband_advisor.py
nni/algorithms/hpo/hyperband_advisor.py
+6
-6
nni/common/__init__.py
nni/common/__init__.py
+1
-0
nni/common/serializer.py
nni/common/serializer.py
+305
-94
nni/experiment/experiment.py
nni/experiment/experiment.py
+2
-2
nni/retiarii/__init__.py
nni/retiarii/__init__.py
+1
-1
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+1
-1
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
+2
-2
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
+3
-4
nni/retiarii/evaluator/pytorch/cgo/trainer.py
nni/retiarii/evaluator/pytorch/cgo/trainer.py
+4
-3
nni/retiarii/evaluator/pytorch/lightning.py
nni/retiarii/evaluator/pytorch/lightning.py
+12
-10
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+1
-1
nni/retiarii/integration.py
nni/retiarii/integration.py
+4
-4
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+0
-4
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+3
-2
No files found.
docs/en_US/NAS/ApiReference.rst
View file @
443ba8c1
...
...
@@ -114,7 +114,9 @@ CGO Execution
Utilities
---------
.. autofunction:: nni.retiarii.serialize
.. autofunction:: nni.retiarii.basic_unit
.. autofunction:: nni.retiarii.model_wrapper
.. autofunction:: nni.retiarii.fixed_arch
...
...
docs/en_US/autotune_ref.rst
View file @
443ba8c1
...
...
@@ -78,3 +78,9 @@ Utilities
---------
.. autofunction:: nni.utils.merge_parameter
.. autofunction:: nni.trace
.. autofunction:: nni.dump
.. autofunction:: nni.load
examples/nas/multi-trial/nasbench101/network.py
View file @
443ba8c1
...
...
@@ -3,7 +3,7 @@ import nni
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
import
torch.nn
as
nn
import
torchmetrics
from
nni.retiarii
import
model_wrapper
,
serialize
,
serialize_cls
from
nni.retiarii
import
model_wrapper
,
serialize
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.nn.pytorch
import
NasBench101Cell
from
nni.retiarii.strategy
import
Random
...
...
@@ -82,7 +82,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
@
serialize_cls
@
nni
.
trace
class
NasBench101TrainingModule
(
pl
.
LightningModule
):
def
__init__
(
self
,
max_epochs
=
108
,
learning_rate
=
0.1
,
weight_decay
=
1e-4
):
super
().
__init__
()
...
...
examples/nas/multi-trial/nasbench201/network.py
View file @
443ba8c1
...
...
@@ -3,7 +3,7 @@ import nni
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
import
torch.nn
as
nn
import
torchmetrics
from
nni.retiarii
import
model_wrapper
,
serialize
,
serialize_cls
from
nni.retiarii
import
model_wrapper
,
serialize
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.nn.pytorch
import
NasBench201Cell
from
nni.retiarii.strategy
import
Random
...
...
@@ -71,7 +71,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
@
serialize_cls
@
nni
.
trace
class
NasBench201TrainingModule
(
pl
.
LightningModule
):
def
__init__
(
self
,
max_epochs
=
200
,
learning_rate
=
0.1
,
weight_decay
=
5e-4
):
super
().
__init__
()
...
...
nni/__init__.py
View file @
443ba8c1
...
...
@@ -9,7 +9,7 @@ except ModuleNotFoundError:
from
.runtime.log
import
init_logger
init_logger
()
from
.common.serializer
import
*
from
.common.serializer
import
trace
,
dump
,
load
from
.runtime.env_vars
import
dispatcher_env_vars
from
.utils
import
ClassArgsValidator
...
...
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
View file @
443ba8c1
...
...
@@ -7,12 +7,12 @@ bohb_advisor.py
import
sys
import
math
import
logging
import
json_tricks
from
schema
import
Schema
,
Optional
import
ConfigSpace
as
CS
import
ConfigSpace.hyperparameters
as
CSH
from
ConfigSpace.read_and_write
import
pcs_new
import
nni
from
nni
import
ClassArgsValidator
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
...
...
@@ -428,7 +428,7 @@ class BOHB(MsgDispatcherBase):
'parameter_source'
:
'algorithm'
,
'parameters'
:
''
}
send
(
CommandType
.
NoMoreTrialJobs
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
NoMoreTrialJobs
,
nni
.
dump
(
ret
))
return
None
assert
self
.
generated_hyper_configs
params
=
self
.
generated_hyper_configs
.
pop
(
0
)
...
...
@@ -459,7 +459,7 @@ class BOHB(MsgDispatcherBase):
"""
ret
=
self
.
_get_one_trial_job
()
if
ret
is
not
None
:
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
NewTrialJob
,
nni
.
dump
(
ret
))
self
.
credit
-=
1
def
handle_update_search_space
(
self
,
data
):
...
...
@@ -536,7 +536,7 @@ class BOHB(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
logger
.
debug
(
'Tuner handle trial end, result is %s'
,
data
)
hyper_params
=
json_tricks
.
load
s
(
data
[
'hyper_params'
])
hyper_params
=
nni
.
load
(
data
[
'hyper_params'
])
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
...
...
@@ -551,7 +551,7 @@ class BOHB(MsgDispatcherBase):
ret
[
'parameter_index'
]
=
one_unsatisfied
[
'parameter_index'
]
# update parameter_id in self.job_id_para_id_map
self
.
job_id_para_id_map
[
ret
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
for
_
in
range
(
self
.
credit
):
self
.
_request_one_trial_job
()
...
...
@@ -584,7 +584,7 @@ class BOHB(MsgDispatcherBase):
"""
logger
.
debug
(
'handle report metric data = %s'
,
data
)
if
'value'
in
data
:
data
[
'value'
]
=
json_tricks
.
load
s
(
data
[
'value'
])
data
[
'value'
]
=
nni
.
load
(
data
[
'value'
])
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
assert
multi_phase_enabled
()
assert
data
[
'trial_job_id'
]
is
not
None
...
...
@@ -599,7 +599,7 @@ class BOHB(MsgDispatcherBase):
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
# update parameter_id in self.job_id_para_id_map
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
else
:
assert
'value'
in
data
value
=
extract_scalar_reward
(
data
[
'value'
])
...
...
@@ -655,7 +655,7 @@ class BOHB(MsgDispatcherBase):
data doesn't have required key 'parameter' and 'value'
"""
for
entry
in
data
:
entry
[
'value'
]
=
json_tricks
.
load
s
(
entry
[
'value'
])
entry
[
'value'
]
=
nni
.
load
(
entry
[
'value'
])
_completed_num
=
0
for
trial_info
in
data
:
logger
.
info
(
"Importing data, current processing progress %s / %s"
,
_completed_num
,
len
(
data
))
...
...
nni/algorithms/hpo/hyperband_advisor.py
View file @
443ba8c1
...
...
@@ -10,10 +10,10 @@ import logging
import
math
import
sys
import
json_tricks
import
numpy
as
np
from
schema
import
Schema
,
Optional
import
nni
from
nni
import
ClassArgsValidator
from
nni.common.hpo_utils
import
validate_search_space
from
nni.runtime.common
import
multi_phase_enabled
...
...
@@ -336,7 +336,7 @@ class Hyperband(MsgDispatcherBase):
def
_request_one_trial_job
(
self
):
ret
=
self
.
_get_one_trial_job
()
if
ret
is
not
None
:
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
NewTrialJob
,
nni
.
dump
(
ret
))
self
.
credit
-=
1
def
_get_one_trial_job
(
self
):
...
...
@@ -365,7 +365,7 @@ class Hyperband(MsgDispatcherBase):
'parameter_source'
:
'algorithm'
,
'parameters'
:
''
}
send
(
CommandType
.
NoMoreTrialJobs
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
NoMoreTrialJobs
,
nni
.
dump
(
ret
))
return
None
assert
self
.
generated_hyper_configs
...
...
@@ -408,7 +408,7 @@ class Hyperband(MsgDispatcherBase):
event: the job's state
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
hyper_params
=
json_tricks
.
load
s
(
data
[
'hyper_params'
])
hyper_params
=
nni
.
load
(
data
[
'hyper_params'
])
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
...
...
@@ -426,7 +426,7 @@ class Hyperband(MsgDispatcherBase):
Data type not supported
"""
if
'value'
in
data
:
data
[
'value'
]
=
json_tricks
.
load
s
(
data
[
'value'
])
data
[
'value'
]
=
nni
.
load
(
data
[
'value'
])
# multiphase? need to check
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
assert
multi_phase_enabled
()
...
...
@@ -440,7 +440,7 @@ class Hyperband(MsgDispatcherBase):
if
data
[
'parameter_index'
]
is
not
None
:
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
json_tricks
.
dump
s
(
ret
))
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
else
:
value
=
extract_scalar_reward
(
data
[
'value'
])
bracket_id
,
i
,
_
=
data
[
'parameter_id'
].
split
(
'_'
)
...
...
nni/common/__init__.py
View file @
443ba8c1
from
.serializer
import
trace
,
dump
,
load
,
is_traceable
nni/common/serializer.py
View file @
443ba8c1
import
abc
import
copy
import
collections.abc
import
base64
import
functools
import
inspect
import
numbers
import
types
import
warnings
from
io
import
IOBase
from
typing
import
Any
,
Union
,
Dict
,
Optional
,
List
,
TypeVar
import
json_tricks
# use json_tricks as serializer backend
import
cloudpickle
# use cloudpickle as backend for unserializable types and instances
__all__
=
[
'trace'
,
'dump'
,
'load'
,
'
SerializableObject
'
]
__all__
=
[
'trace'
,
'dump'
,
'load'
,
'
Translatable'
,
'Traceable'
,
'is_traceable
'
]
T
=
TypeVar
(
'T'
)
class
SerializableObject
:
class
Traceable
(
abc
.
ABC
):
"""
A traceable object have copy and dict. Copy and mutate are used to copy the object for further mutations.
Dict returns a TraceDictType to enable serialization.
"""
@
abc
.
abstractmethod
def
trace_copy
(
self
)
->
'Traceable'
:
"""
Perform a shallow copy.
NOTE: NONE of the attributes will be preserved.
This is the one that should be used when you want to "mutate" a serializable object.
"""
...
@
property
@
abc
.
abstractmethod
def
trace_symbol
(
self
)
->
Any
:
"""
Symbol object. Could be a class or a function.
``get_hybrid_cls_or_func_name`` and ``import_cls_or_func_from_hybrid_name`` is a pair to
convert the symbol into a string and convert the string back to symbol.
"""
...
@
property
@
abc
.
abstractmethod
def
trace_args
(
self
)
->
List
[
Any
]:
"""
List of positional arguments passed to symbol. Usually empty if ``kw_only`` is true,
in which case all the positional arguments are converted into keyword arguments.
"""
...
@
property
@
abc
.
abstractmethod
def
trace_kwargs
(
self
)
->
Dict
[
str
,
Any
]:
"""
Dict of keyword arguments.
"""
...
class
Translatable
(
abc
.
ABC
):
"""
Inherit this class and implement ``translate`` when the inner class needs a different
parameter from the wrapper class in its init function.
"""
@
abc
.
abstractmethod
def
_translate
(
self
)
->
Any
:
pass
@
staticmethod
def
_translate_argument
(
d
:
Any
)
->
Any
:
if
isinstance
(
d
,
Translatable
):
return
d
.
_translate
()
return
d
def
is_traceable
(
obj
:
Any
)
->
bool
:
"""
Check whether an object is a traceable instance (not type).
"""
return
hasattr
(
obj
,
'trace_copy'
)
and
\
hasattr
(
obj
,
'trace_symbol'
)
and
\
hasattr
(
obj
,
'trace_args'
)
and
\
hasattr
(
obj
,
'trace_kwargs'
)
and
\
not
inspect
.
isclass
(
obj
)
class
SerializableObject
(
Traceable
):
"""
Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
"""
def
__init__
(
self
,
symbol
:
T
,
args
:
List
[
Any
],
kwargs
:
Dict
[
str
,
Any
],
_self_contained
:
bool
=
False
):
def
__init__
(
self
,
symbol
:
T
,
args
:
List
[
Any
],
kwargs
:
Dict
[
str
,
Any
],
call_super
:
bool
=
False
):
# use dict to avoid conflicts with user's getattr and setattr
self
.
__dict__
[
'_nni_symbol'
]
=
symbol
self
.
__dict__
[
'_nni_args'
]
=
args
self
.
__dict__
[
'_nni_kwargs'
]
=
kwargs
self
.
__dict__
[
'_nni_call_super'
]
=
call_super
self
.
__dict__
[
'_nni_self_contained'
]
=
_self_contained
if
_self_contained
:
# this is for internal usage only.
# kwargs is used to init the full object in the same object as this one, for simpler implementation.
super
().
__init__
(
*
self
.
_recursive_init
(
args
),
**
self
.
_recursive_init
(
kwargs
))
def
get
(
self
)
->
Any
:
"""
Get the original object.
"""
if
self
.
_get_nni_attr
(
'self_contained'
):
return
self
if
'_nni_cache'
not
in
self
.
__dict__
:
self
.
__dict__
[
'_nni_cache'
]
=
self
.
_get_nni_attr
(
'symbol'
)(
*
self
.
_recursive_init
(
self
.
_get_nni_attr
(
'args'
)),
**
self
.
_recursive_init
(
self
.
_get_nni_attr
(
'kwargs'
))
if
call_super
:
# call super means that the serializable object is by itself an object of the target class
super
().
__init__
(
*
[
_argument_processor
(
arg
)
for
arg
in
args
],
**
{
kw
:
_argument_processor
(
arg
)
for
kw
,
arg
in
kwargs
.
items
()}
)
return
self
.
__dict__
[
'_nni_cache'
]
def
copy
(
self
)
->
Union
[
T
,
'SerializableObject'
]:
"""
Perform a shallow copy. Will throw away the self-contain property for classes (refer to implementation).
This is the one that should be used when you want to "mutate" a serializable object.
"""
def
trace_copy
(
self
)
->
Union
[
T
,
'SerializableObject'
]:
return
SerializableObject
(
self
.
_get_nni_attr
(
'
symbol
'
)
,
self
.
_get_nni_attr
(
'
args
'
)
,
self
.
_get_nni_attr
(
'kwargs'
)
self
.
trace_
symbol
,
[
copy
.
copy
(
arg
)
for
arg
in
self
.
trace_
args
]
,
{
k
:
copy
.
copy
(
v
)
for
k
,
v
in
self
.
trace_kwargs
.
items
()},
)
def
__json_encode__
(
self
):
ret
=
{
'__symbol__'
:
_get_hybrid_cls_or_func_name
(
self
.
_get_nni_attr
(
'symbol'
))}
if
self
.
_get_nni_attr
(
'args'
):
ret
[
'__args__'
]
=
self
.
_get_nni_attr
(
'args'
)
ret
[
'__kwargs__'
]
=
self
.
_get_nni_attr
(
'kwargs'
)
return
ret
@
property
def
trace_symbol
(
self
)
->
Any
:
return
self
.
_get_nni_attr
(
'symbol'
)
@
trace_symbol
.
setter
def
trace_symbol
(
self
,
symbol
:
Any
)
->
None
:
# for mutation purposes
self
.
__dict__
[
'_nni_symbol'
]
=
symbol
@
property
def
trace_args
(
self
)
->
List
[
Any
]:
return
self
.
_get_nni_attr
(
'args'
)
def
_get_nni_attr
(
self
,
name
):
@
trace_args
.
setter
def
trace_args
(
self
,
args
:
List
[
Any
]):
self
.
__dict__
[
'_nni_args'
]
=
args
@
property
def
trace_kwargs
(
self
)
->
Dict
[
str
,
Any
]:
return
self
.
_get_nni_attr
(
'kwargs'
)
@
trace_kwargs
.
setter
def
trace_kwargs
(
self
,
kwargs
:
Dict
[
str
,
Any
]):
self
.
__dict__
[
'_nni_kwargs'
]
=
kwargs
def
_get_nni_attr
(
self
,
name
:
str
)
->
Any
:
return
self
.
__dict__
[
'_nni_'
+
name
]
def
__repr__
(
self
):
if
self
.
_get_nni_attr
(
'self_contained'
):
return
repr
(
self
)
if
'_nni_cache'
in
self
.
__dict__
:
return
repr
(
self
.
_get_nni_attr
(
'cache'
))
if
self
.
_get_nni_attr
(
'call_super'
):
return
super
().
__repr__
()
return
'SerializableObject('
+
\
', '
.
join
([
'type='
+
self
.
_get_nni_attr
(
'symbol'
).
__name__
]
+
[
repr
(
d
)
for
d
in
self
.
_get_nni_attr
(
'args'
)]
+
[
k
+
'='
+
repr
(
v
)
for
k
,
v
in
self
.
_get_nni_attr
(
'kwargs'
).
items
()])
+
\
')'
@
staticmethod
def
_recursive_init
(
d
):
# auto-call get() to prevent type-converting in downstreaming functions
if
isinstance
(
d
,
dict
):
return
{
k
:
v
.
get
()
if
isinstance
(
v
,
SerializableObject
)
else
v
for
k
,
v
in
d
.
items
()}
else
:
return
[
v
.
get
()
if
isinstance
(
v
,
SerializableObject
)
else
v
for
v
in
d
]
def
inject_trace_info
(
obj
:
Any
,
symbol
:
T
,
args
:
List
[
Any
],
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
# If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
def
getter_factory
(
x
):
return
lambda
self
:
self
.
__dict__
[
'_nni_'
+
x
]
def
setter_factory
(
x
):
def
setter
(
self
,
val
):
self
.
__dict__
[
'_nni_'
+
x
]
=
val
return
setter
def
trace_copy
(
self
):
return
SerializableObject
(
self
.
trace_symbol
,
[
copy
.
copy
(
arg
)
for
arg
in
self
.
trace_args
],
{
k
:
copy
.
copy
(
v
)
for
k
,
v
in
self
.
trace_kwargs
.
items
()},
)
attributes
=
{
'trace_symbol'
:
property
(
getter_factory
(
'symbol'
),
setter_factory
(
'symbol'
)),
'trace_args'
:
property
(
getter_factory
(
'args'
),
setter_factory
(
'args'
)),
'trace_kwargs'
:
property
(
getter_factory
(
'kwargs'
),
setter_factory
(
'kwargs'
)),
'trace_copy'
:
trace_copy
}
def
trace
(
cls_or_func
:
T
=
None
,
*
,
kw_only
:
bool
=
True
)
->
Union
[
T
,
SerializableObject
]:
if
hasattr
(
obj
,
'__class__'
)
and
hasattr
(
obj
,
'__dict__'
):
for
name
,
method
in
attributes
.
items
():
setattr
(
obj
.
__class__
,
name
,
method
)
else
:
wrapper
=
type
(
'wrapper'
,
(
Traceable
,
type
(
obj
)),
attributes
)
obj
=
wrapper
(
obj
)
# pylint: disable=abstract-class-instantiated
# make obj complying with the interface of traceable, though we cannot change its base class
obj
.
__dict__
.
update
(
_nni_symbol
=
symbol
,
_nni_args
=
args
,
_nni_kwargs
=
kwargs
)
return
obj
def
trace
(
cls_or_func
:
T
=
None
,
*
,
kw_only
:
bool
=
True
)
->
Union
[
T
,
Traceable
]:
"""
Annotate a function or a class if you want to preserve where it comes from.
This is usually used in the following scenarios:
...
...
@@ -98,16 +205,16 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab
When a class/function is annotated, all the instances/calls will return a object as it normally will.
Although the object might act like a normal object, it's actually a different object with NNI-specific properties.
To get the original object, you should use ``obj.get()`` to retrieve. The retrieved object can be used
like the original one, but there are still subtle differences in implementation.
Note that when using the result from a trace in another trace-able function/class, ``.get()`` is automatically
called, so that you don't have to worry about type-converting.
One exception is that if your function returns None, it will return an empty SerializableObject instead,
which should raise your attention when you want to check whether the None ``is None``.
Also it records extra information about where this object comes from. That's why it's called "trace".
When parameters of functions are received, it is first stored, and then a shallow copy will be passed to inner function.
This is to prevent mutable objects gets modified in the inner function.
When the function finished execution, we also record extra information about where this object comes from.
That's why it's called "trace".
When call ``nni.dump``, that information will be used, by default.
If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspect the argument
If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspect
ing
the argument
list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
Example:
...
...
@@ -120,10 +227,18 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab
"""
def
wrap
(
cls_or_func
):
# already annotated, do nothing
if
getattr
(
cls_or_func
,
'_traced'
,
False
):
return
cls_or_func
if
isinstance
(
cls_or_func
,
type
):
return
_trace_cls
(
cls_or_func
,
kw_only
)
cls_or_func
=
_trace_cls
(
cls_or_func
,
kw_only
)
elif
_is_function
(
cls_or_func
):
cls_or_func
=
_trace_func
(
cls_or_func
,
kw_only
)
else
:
return
_trace_func
(
cls_or_func
,
kw_only
)
raise
TypeError
(
f
'
{
cls_or_func
}
of type
{
type
(
cls_or_func
)
}
is not supported to be traced. '
'File an issue at https://github.com/microsoft/nni/issues if you believe this is a mistake.'
)
cls_or_func
.
_traced
=
True
return
cls_or_func
# if we're being called as @trace()
if
cls_or_func
is
None
:
...
...
@@ -133,8 +248,8 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab
return
wrap
(
cls_or_func
)
def
dump
(
obj
:
Any
,
fp
:
Optional
[
Any
]
=
None
,
use_trace
:
bool
=
True
,
pickle_size_limit
:
int
=
4096
,
**
json_tricks_kwargs
)
->
Union
[
str
,
bytes
]:
def
dump
(
obj
:
Any
,
fp
:
Optional
[
Any
]
=
None
,
*
,
use_trace
:
bool
=
True
,
pickle_size_limit
:
int
=
4096
,
allow_nan
:
bool
=
True
,
**
json_tricks_kwargs
)
->
Union
[
str
,
bytes
]:
"""
Convert a nested data structure to a json string. Save to file if fp is specified.
Use json-tricks as main backend. For unhandled cases in json-tricks, use cloudpickle.
...
...
@@ -143,10 +258,14 @@ def dump(obj: Any, fp: Optional[Any] = None, use_trace: bool = True, pickle_size
Parameters
----------
obj : any
The object to dump.
fp : file handler or path
File to write to. Keep it none if you want to dump a string.
pickle_size_limit : int
This is set to avoid too long serialization result. Set to -1 to disable size check.
allow_nan : bool
Whether to allow nan to be serialized. Different from default value in json-tricks, our default value is true.
json_tricks_kwargs : dict
Other keyword arguments passed to json tricks (backend), e.g., indent=2.
...
...
@@ -171,19 +290,32 @@ def dump(obj: Any, fp: Optional[Any] = None, use_trace: bool = True, pickle_size
functools
.
partial
(
_json_tricks_any_object_encode
,
pickle_size_limit
=
pickle_size_limit
),
]
json_tricks_kwargs
[
'allow_nan'
]
=
allow_nan
if
fp
is
not
None
:
return
json_tricks
.
dump
(
obj
,
fp
,
obj_encoders
=
encoders
,
**
json_tricks_kwargs
)
else
:
return
json_tricks
.
dumps
(
obj
,
obj_encoders
=
encoders
,
**
json_tricks_kwargs
)
def
load
(
string
:
str
=
None
,
fp
:
Optional
[
Any
]
=
None
,
**
json_tricks_kwargs
)
->
Any
:
def
load
(
string
:
Optional
[
str
]
=
None
,
*
,
fp
:
Optional
[
Any
]
=
None
,
ignore_comments
:
bool
=
True
,
**
json_tricks_kwargs
)
->
Any
:
"""
Load the string or from file, and convert it to a complex data structure.
At least one of string or fp has to be not none.
Parameters
----------
string : str
JSON string to parse. Can be set to none if fp is used.
fp : str
File path to load JSON from. Can be set to none if string is used.
ignore_comments : bool
Remove comments (starting with ``#`` or ``//``). Default is true.
Returns
-------
any
The loaded object.
"""
assert
string
is
not
None
or
fp
is
not
None
# see encoders for explanation
...
...
@@ -201,7 +333,12 @@ def load(string: str = None, fp: Optional[Any] = None, **json_tricks_kwargs) ->
_json_tricks_any_object_decode
]
# to bypass a deprecation warning in json-tricks
json_tricks_kwargs
[
'ignore_comments'
]
=
ignore_comments
if
string
is
not
None
:
if
isinstance
(
string
,
IOBase
):
raise
TypeError
(
f
'Expect a string, found a
{
string
}
. If you intend to use a file, use `nni.load(fp=file)`'
)
return
json_tricks
.
loads
(
string
,
obj_pairs_hooks
=
hooks
,
**
json_tricks_kwargs
)
else
:
return
json_tricks
.
load
(
fp
,
obj_pairs_hooks
=
hooks
,
**
json_tricks_kwargs
)
...
...
@@ -214,11 +351,57 @@ def _trace_cls(base, kw_only):
class
wrapper
(
SerializableObject
,
base
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
# store a copy of initial parameters
args
,
kwargs
=
_
get
_arguments
_as_dict
(
base
.
__init__
,
args
,
kwargs
,
kw_only
)
args
,
kwargs
=
_
formulate
_arguments
(
base
.
__init__
,
args
,
kwargs
,
kw_only
,
is_class_init
=
True
)
# calling serializable object init to initialize the full object
super
().
__init__
(
symbol
=
base
,
args
=
args
,
kwargs
=
kwargs
,
_self_contained
=
True
)
super
().
__init__
(
symbol
=
base
,
args
=
args
,
kwargs
=
kwargs
,
call_super
=
True
)
_copy_class_wrapper_attributes
(
base
,
wrapper
)
return
wrapper
def
_trace_func
(
func
,
kw_only
):
@
functools
.
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
# similar to class, store parameters here
args
,
kwargs
=
_formulate_arguments
(
func
,
args
,
kwargs
,
kw_only
)
# it's not clear whether this wrapper can handle all the types in python
# There are many cases here: https://docs.python.org/3/reference/datamodel.html
# but it looks that we have handled most commonly used cases
res
=
func
(
*
[
_argument_processor
(
arg
)
for
arg
in
args
],
**
{
kw
:
_argument_processor
(
arg
)
for
kw
,
arg
in
kwargs
.
items
()}
)
if
res
is
None
:
# don't call super, makes no sense.
# an empty serializable object is "none". Don't check it though.
res
=
SerializableObject
(
func
,
args
,
kwargs
,
call_super
=
False
)
elif
hasattr
(
res
,
'__class__'
)
and
hasattr
(
res
,
'__dict__'
):
# is a class, inject interface directly
# need to be done before primitive types because there could be inheritance here.
res
=
inject_trace_info
(
res
,
func
,
args
,
kwargs
)
elif
isinstance
(
res
,
(
collections
.
abc
.
Callable
,
types
.
ModuleType
,
IOBase
)):
raise
TypeError
(
f
'Try to add trace info to
{
res
}
, but functions and modules are not supported.'
)
elif
isinstance
(
res
,
(
numbers
.
Number
,
collections
.
abc
.
Sequence
,
collections
.
abc
.
Set
,
collections
.
abc
.
Mapping
)):
# handle primitive types like int, str, set, dict, tuple
# NOTE: simple types including none, bool, int, float, list, tuple, dict
# will be directly captured by python json encoder
# and thus not possible to restore the trace parameters after dump and reload.
# this is a known limitation.
res
=
inject_trace_info
(
res
,
func
,
args
,
kwargs
)
else
:
raise
TypeError
(
f
'Try to add trace info to
{
res
}
, but the type "
{
type
(
res
)
}
" is unknown. '
'Please file an issue at https://github.com/microsoft/nni/issues'
)
return
res
return
wrapper
def
_copy_class_wrapper_attributes
(
base
,
wrapper
):
_MISSING
=
'_missing'
for
k
in
functools
.
WRAPPER_ASSIGNMENTS
:
# assign magic attributes like __module__, __qualname__, __doc__
...
...
@@ -229,25 +412,29 @@ def _trace_cls(base, kw_only):
except
AttributeError
:
pass
return
wrapper
wrapper
.
__wrapped__
=
base
def
_trace_func
(
func
,
kw_only
):
@
functools
.
wraps
def
wrapper
(
*
args
,
**
kwargs
):
# similar to class, store parameters here
args
,
kwargs
=
_get_arguments_as_dict
(
func
,
args
,
kwargs
,
kw_only
)
return
SerializableObject
(
func
,
args
,
kwargs
)
return
wrapper
def
_argument_processor
(
arg
):
# 1) translate
# handle cases like ValueChoice
# This is needed because sometimes the recorded arguments are meant to be different from what the inner object receives.
arg
=
Translatable
.
_translate_argument
(
arg
)
# 2) prevent the stored parameters to be mutated by inner class.
# an example: https://github.com/microsoft/nni/issues/4329
if
isinstance
(
arg
,
(
collections
.
abc
.
MutableMapping
,
collections
.
abc
.
MutableSequence
,
collections
.
abc
.
MutableSet
)):
arg
=
copy
.
copy
(
arg
)
return
arg
def
_get_arguments_as_dict
(
func
,
args
,
kwargs
,
kw_only
):
def
_formulate_arguments
(
func
,
args
,
kwargs
,
kw_only
,
is_class_init
=
False
):
# This is to formulate the arguments and make them well-formed.
if
kw_only
:
# get arguments passed to a function, and save it as a dict
argname_list
=
list
(
inspect
.
signature
(
func
).
parameters
.
keys
())[
1
:]
argname_list
=
list
(
inspect
.
signature
(
func
).
parameters
.
keys
())
if
is_class_init
:
argname_list
=
argname_list
[
1
:]
full_args
=
{}
full_args
.
update
(
kwargs
)
# match arguments with given arguments
# args should be longer than given list, because args can be used in a kwargs way
...
...
@@ -255,9 +442,18 @@ def _get_arguments_as_dict(func, args, kwargs, kw_only):
for
argname
,
value
in
zip
(
argname_list
,
args
):
full_args
[
argname
]
=
value
# use kwargs to override
full_args
.
update
(
kwargs
)
args
,
kwargs
=
[],
full_args
return
args
,
kwargs
return
list
(
args
),
kwargs
def
_is_function
(
obj
:
Any
)
->
bool
:
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return
isinstance
(
obj
,
(
types
.
FunctionType
,
types
.
BuiltinFunctionType
,
types
.
MethodType
,
types
.
BuiltinMethodType
))
def
_import_cls_or_func_from_name
(
target
:
str
)
->
Any
:
...
...
@@ -268,6 +464,12 @@ def _import_cls_or_func_from_name(target: str) -> Any:
return
getattr
(
module
,
identifier
)
def
_strip_trace_type
(
traceable
:
Any
)
->
Any
:
if
getattr
(
traceable
,
'_traced'
,
False
):
return
traceable
.
__wrapped__
return
traceable
def
_get_cls_or_func_name
(
cls_or_func
:
Any
)
->
str
:
module_name
=
cls_or_func
.
__module__
if
module_name
==
'__main__'
:
...
...
@@ -276,7 +478,8 @@ def _get_cls_or_func_name(cls_or_func: Any) -> str:
try
:
imported
=
_import_cls_or_func_from_name
(
full_name
)
if
imported
!=
cls_or_func
:
# ignores the differences in trace
if
_strip_trace_type
(
imported
)
!=
_strip_trace_type
(
cls_or_func
):
raise
ImportError
(
f
'Imported
{
imported
}
is not same as expected. The function might be dynamically created.'
)
except
ImportError
:
raise
ImportError
(
f
'Import
{
cls_or_func
.
__name__
}
from "
{
module_name
}
" failed.'
)
...
...
@@ -284,12 +487,12 @@ def _get_cls_or_func_name(cls_or_func: Any) -> str:
return
full_name
def
_
get_hybrid_cls_or_func_name
(
cls_or_func
:
Any
,
pickle_size_limit
:
int
=
4096
)
->
str
:
def
get_hybrid_cls_or_func_name
(
cls_or_func
:
Any
,
pickle_size_limit
:
int
=
4096
)
->
str
:
try
:
name
=
_get_cls_or_func_name
(
cls_or_func
)
# import success, use a path format
return
'path:'
+
name
except
ImportError
:
except
(
ImportError
,
AttributeError
)
:
b
=
cloudpickle
.
dumps
(
cls_or_func
)
if
len
(
b
)
>
pickle_size_limit
:
raise
ValueError
(
f
'Pickle too large when trying to dump
{
cls_or_func
}
. '
...
...
@@ -298,7 +501,7 @@ def _get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096
return
'bytes:'
+
base64
.
b64encode
(
b
).
decode
()
def
_
import_cls_or_func_from_hybrid_name
(
s
:
str
)
->
Any
:
def
import_cls_or_func_from_hybrid_name
(
s
:
str
)
->
Any
:
if
s
.
startswith
(
'bytes:'
):
b
=
base64
.
b64decode
(
s
.
split
(
':'
,
1
)[
-
1
])
return
cloudpickle
.
loads
(
b
)
...
...
@@ -308,40 +511,47 @@ def _import_cls_or_func_from_hybrid_name(s: str) -> Any:
def
_json_tricks_func_or_cls_encode
(
cls_or_func
:
Any
,
primitives
:
bool
=
False
,
pickle_size_limit
:
int
=
4096
)
->
str
:
if
not
isinstance
(
cls_or_func
,
type
)
and
not
callable
(
cls_or_func
):
if
not
isinstance
(
cls_or_func
,
type
)
and
not
_is_function
(
cls_or_func
):
# not a function or class, continue
return
cls_or_func
return
{
'__nni_type__'
:
_
get_hybrid_cls_or_func_name
(
cls_or_func
,
pickle_size_limit
)
'__nni_type__'
:
get_hybrid_cls_or_func_name
(
cls_or_func
,
pickle_size_limit
)
}
def
_json_tricks_func_or_cls_decode
(
s
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
s
,
dict
)
and
'__nni_type__'
in
s
:
s
=
s
[
'__nni_type__'
]
return
_
import_cls_or_func_from_hybrid_name
(
s
)
return
import_cls_or_func_from_hybrid_name
(
s
)
return
s
def
_json_tricks_serializable_object_encode
(
obj
:
Any
,
primitives
:
bool
=
False
,
use_trace
:
bool
=
True
)
->
Dict
[
str
,
Any
]:
# Encodes a serializable object instance to json.
# If primitives, the representation is simplified and cannot be recovered!
# do nothing to instance that is not a serializable object and do not use trace
if
not
use_trace
or
not
is
instance
(
obj
,
SerializableObject
):
if
not
use_trace
or
not
is
_traceable
(
obj
):
return
obj
return
obj
.
__json_encode__
()
if
isinstance
(
obj
.
trace_symbol
,
property
):
# commonly made mistake when users forget to call the traced function/class.
warnings
.
warn
(
f
'The symbol of
{
obj
}
is found to be a property. Did you forget to create the instance with ``xx(...)``?'
)
ret
=
{
'__symbol__'
:
get_hybrid_cls_or_func_name
(
obj
.
trace_symbol
)}
if
obj
.
trace_args
:
ret
[
'__args__'
]
=
obj
.
trace_args
if
obj
.
trace_kwargs
:
ret
[
'__kwargs__'
]
=
obj
.
trace_kwargs
return
ret
def
_json_tricks_serializable_object_decode
(
obj
:
Dict
[
str
,
Any
])
->
Any
:
if
isinstance
(
obj
,
dict
)
and
'__symbol__'
in
obj
and
'__kwargs__'
in
obj
:
return
SerializableObject
(
_import_cls_or_func_from_hybrid_name
(
obj
[
'__symbol__'
]),
getattr
(
obj
,
'__args__'
,
[]),
obj
[
'__kwargs__'
]
)
if
isinstance
(
obj
,
dict
)
and
'__symbol__'
in
obj
:
symbol
=
import_cls_or_func_from_hybrid_name
(
obj
[
'__symbol__'
])
args
=
obj
.
get
(
'__args__'
,
[])
kwargs
=
obj
.
get
(
'__kwargs__'
,
{})
return
trace
(
symbol
)(
*
args
,
**
kwargs
)
return
obj
...
...
@@ -353,8 +563,9 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si
if
hasattr
(
obj
,
'__class__'
)
and
(
hasattr
(
obj
,
'__dict__'
)
or
hasattr
(
obj
,
'__slots__'
)):
b
=
cloudpickle
.
dumps
(
obj
)
if
len
(
b
)
>
pickle_size_limit
:
raise
ValueError
(
f
'Pickle too large when trying to dump
{
obj
}
. '
'Please try to raise pickle_size_limit if you insist.'
)
raise
ValueError
(
f
'Pickle too large when trying to dump
{
obj
}
. This might be caused by classes that are '
'not decorated by @nni.trace. Another option is to force bytes pickling and '
'try to raise pickle_size_limit.'
)
# use base64 to dump a bytes array
return
{
'__nni_obj__'
:
base64
.
b64encode
(
b
).
decode
()
...
...
nni/experiment/experiment.py
View file @
443ba8c1
...
...
@@ -6,11 +6,11 @@ from subprocess import Popen
import
time
from
typing
import
Optional
,
Union
,
List
,
overload
,
Any
import
json_tricks
import
colorama
import
psutil
import
nni.runtime.log
from
nni.common
import
dump
from
.config
import
ExperimentConfig
,
AlgorithmConfig
from
.data
import
TrialJob
,
TrialMetricData
,
TrialResult
...
...
@@ -439,7 +439,7 @@ class Experiment:
value: dict
New search_space.
"""
value
=
json_tricks
.
dump
s
(
value
)
value
=
dump
(
value
)
self
.
_update_experiment_profile
(
'searchSpace'
,
value
)
def
update_max_trial_number
(
self
,
value
:
int
):
...
...
nni/retiarii/__init__.py
View file @
443ba8c1
...
...
@@ -6,4 +6,4 @@ from .graph import *
from
.execution
import
*
from
.fixed
import
fixed_arch
from
.mutator
import
*
from
.serializer
import
basic_unit
,
json_dump
,
json_dumps
,
json_load
,
json_loads
,
serialize
,
serialize_cls
,
model_wrapper
from
.serializer
import
basic_unit
,
model_wrapper
,
serialize
,
serialize_cls
nni/retiarii/converter/graph_gen.py
View file @
443ba8c1
...
...
@@ -637,7 +637,7 @@ class GraphConverter:
original_type_name
not
in
MODULE_EXCEPT_LIST
:
# this is a basic module from pytorch, no need to parse its graph
m_attrs
=
get_init_parameters_or_fail
(
module
)
elif
getattr
(
module
,
'_
stop_parsing
'
,
False
):
elif
getattr
(
module
,
'_
nni_basic_unit
'
,
False
):
# this module is marked as serialize, won't continue to parse
m_attrs
=
get_init_parameters_or_fail
(
module
)
if
m_attrs
is
not
None
:
...
...
nni/retiarii/evaluator/pytorch/cgo/accelerator.py
View file @
443ba8c1
...
...
@@ -10,7 +10,7 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import Trainin
from
pytorch_lightning.trainer
import
Trainer
from
pytorch_lightning.trainer.connectors.accelerator_connector
import
AcceleratorConnector
from
....serializer
import
serialize_cls
import
nni
class
BypassPlugin
(
TrainingTypePlugin
):
...
...
@@ -126,7 +126,7 @@ def get_accelerator_connector(
)
@
serialize_cls
@
nni
.
trace
class
BypassAccelerator
(
Accelerator
):
def
__init__
(
self
,
precision_plugin
=
None
,
device
=
"cpu"
,
**
trainer_kwargs
):
if
precision_plugin
is
None
:
...
...
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
View file @
443ba8c1
...
...
@@ -14,10 +14,9 @@ import nni
from
..lightning
import
LightningModule
,
_AccuracyWithLogits
,
Lightning
from
.trainer
import
Trainer
from
....serializer
import
serialize_cls
@
serialize_cls
@
nni
.
trace
class
_MultiModelSupervisedLearningModule
(
LightningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
torchmetrics
.
Metric
],
n_models
:
int
=
0
,
...
...
@@ -126,7 +125,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
super
().
__init__
(
criterion
,
metrics
,
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
)
@
serialize_cls
@
nni
.
trace
class
_ClassificationModule
(
MultiModelSupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
CrossEntropyLoss
,
learning_rate
:
float
=
0.001
,
...
...
@@ -174,7 +173,7 @@ class Classification(Lightning):
train_dataloader
=
train_dataloader
,
val_dataloaders
=
val_dataloaders
)
@
serialize_cls
@
nni
.
trace
class
_RegressionModule
(
MultiModelSupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
MSELoss
,
learning_rate
:
float
=
0.001
,
...
...
nni/retiarii/evaluator/pytorch/cgo/trainer.py
View file @
443ba8c1
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
pytorch_lightning
as
pl
from
....serializer
import
serialize_cls
import
nni
from
.accelerator
import
BypassAccelerator
@
serialize_cls
@
nni
.
trace
class
Trainer
(
pl
.
Trainer
):
"""
Trainer for cross-graph optimization.
...
...
nni/retiarii/evaluator/pytorch/lightning.py
View file @
443ba8c1
...
...
@@ -10,17 +10,17 @@ import pytorch_lightning as pl
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torchmetrics
from
torch.utils.data
import
DataLoader
import
torch.utils.data
as
torch_data
import
nni
from
nni.common.serializer
import
is_traceable
try
:
from
.cgo
import
trainer
as
cgo_trainer
cgo_import_failed
=
False
except
ImportError
:
cgo_import_failed
=
True
from
...graph
import
Evaluator
from
...serializer
import
serialize_cls
from
nni.retiarii.graph
import
Evaluator
__all__
=
[
'LightningModule'
,
'Trainer'
,
'DataLoader'
,
'Lightning'
,
'Classification'
,
'Regression'
]
...
...
@@ -40,9 +40,10 @@ class LightningModule(pl.LightningModule):
self
.
model
=
model
Trainer
=
serialize_cls
(
pl
.
Trainer
)
DataLoader
=
serialize_cls
(
DataLoader
)
Trainer
=
nni
.
trace
(
pl
.
Trainer
)
DataLoader
=
nni
.
trace
(
torch_data
.
DataLoader
)
@
nni
.
trace
class
Lightning
(
Evaluator
):
"""
Delegate the whole training to PyTorch Lightning.
...
...
@@ -74,9 +75,10 @@ class Lightning(Evaluator):
val_dataloaders
:
Union
[
DataLoader
,
List
[
DataLoader
],
None
]
=
None
):
assert
isinstance
(
lightning_module
,
LightningModule
),
f
'Lightning module must be an instance of
{
__name__
}
.LightningModule.'
if
cgo_import_failed
:
assert
isinstance
(
trainer
,
T
rainer
),
f
'Trainer must be imported from
{
__name__
}
'
assert
isinstance
(
trainer
,
pl
.
Trainer
)
and
is_traceable
(
t
rainer
),
f
'Trainer must be imported from
{
__name__
}
'
else
:
assert
isinstance
(
trainer
,
Trainer
)
or
isinstance
(
trainer
,
cgo_trainer
.
Trainer
),
\
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert
(
isinstance
(
trainer
,
pl
.
Trainer
)
and
is_traceable
(
trainer
))
or
isinstance
(
trainer
,
cgo_trainer
.
Trainer
),
\
f
'Trainer must be imported from
{
__name__
}
or nni.retiarii.evaluator.pytorch.cgo.trainer'
assert
_check_dataloader
(
train_dataloader
),
f
'Wrong dataloader type. Try import DataLoader from
{
__name__
}
.'
assert
_check_dataloader
(
val_dataloaders
),
f
'Wrong dataloader type. Try import DataLoader from
{
__name__
}
.'
...
...
@@ -135,7 +137,7 @@ def _check_dataloader(dataloader):
return
True
if
isinstance
(
dataloader
,
list
):
return
all
([
_check_dataloader
(
d
)
for
d
in
dataloader
])
return
isinstance
(
dataloader
,
DataL
oader
)
return
isinstance
(
dataloader
,
torch_data
.
DataLoader
)
and
is_traceable
(
datal
oader
)
### The following are some commonly used Lightning modules ###
...
...
@@ -219,7 +221,7 @@ class _AccuracyWithLogits(torchmetrics.Accuracy):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
@
serialize_cls
@
nni
.
trace
class
_ClassificationModule
(
_SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
CrossEntropyLoss
,
learning_rate
:
float
=
0.001
,
...
...
@@ -272,7 +274,7 @@ class Classification(Lightning):
train_dataloader
=
train_dataloader
,
val_dataloaders
=
val_dataloaders
)
@
serialize_cls
@
nni
.
trace
class
_RegressionModule
(
_SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
MSELoss
,
learning_rate
:
float
=
0.001
,
...
...
nni/retiarii/execution/cgo_engine.py
View file @
443ba8c1
...
...
@@ -200,7 +200,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# replace the module with a new instance whose n_models is set
# n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
new_module_init_params
=
model
.
evaluator
.
module
.
_init_parameter
s
.
copy
()
new_module_init_params
=
model
.
evaluator
.
module
.
trace_kwarg
s
.
copy
()
# MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
new_module_init_params
[
'n_models'
]
=
len
(
multi_model
)
...
...
nni/retiarii/integration.py
View file @
443ba8c1
...
...
@@ -4,13 +4,13 @@
import
logging
from
typing
import
Any
,
Callable
import
nni
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.utils
import
MetricType
from
.graph
import
MetricData
from
.integration_api
import
register_advisor
from
.serializer
import
json_dumps
,
json_loads
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -121,7 +121,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
'placement_constraint'
:
placement_constraint
}
_logger
.
debug
(
'New trial sent: %s'
,
new_trial
)
send
(
CommandType
.
NewTrialJob
,
json_
dump
s
(
new_trial
))
send
(
CommandType
.
NewTrialJob
,
nni
.
dump
(
new_trial
))
if
self
.
send_trial_callback
is
not
None
:
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
return
self
.
parameters_count
...
...
@@ -140,7 +140,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
def
handle_trial_end
(
self
,
data
):
_logger
.
debug
(
'Trial end: %s'
,
data
)
self
.
trial_end_callback
(
json_
load
s
(
data
[
'hyper_params'
])[
'parameter_id'
],
# pylint: disable=not-callable
self
.
trial_end_callback
(
nni
.
load
(
data
[
'hyper_params'
])[
'parameter_id'
],
# pylint: disable=not-callable
data
[
'event'
]
==
'SUCCEEDED'
)
def
handle_report_metric_data
(
self
,
data
):
...
...
@@ -156,7 +156,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
@
staticmethod
def
_process_value
(
value
)
->
Any
:
# hopefully a float
value
=
json_
load
s
(
value
)
value
=
nni
.
load
(
value
)
if
isinstance
(
value
,
dict
):
if
'default'
in
value
:
return
value
[
'default'
]
...
...
nni/retiarii/integration_api.py
View file @
443ba8c1
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
json
from
typing
import
NewType
,
Any
import
nni
from
.serializer
import
json_loads
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
RetiariiAdvisor
=
NewType
(
'RetiariiAdvisor'
,
Any
)
...
...
@@ -41,7 +38,6 @@ def receive_trial_parameters() -> dict:
Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
"""
params
=
nni
.
get_next_parameter
()
params
=
json_loads
(
json
.
dumps
(
params
))
return
params
...
...
nni/retiarii/nn/pytorch/api.py
View file @
443ba8c1
...
...
@@ -8,8 +8,9 @@ from typing import Any, List, Union, Dict, Optional
import
torch
import
torch.nn
as
nn
from
...serializer
import
Translatable
,
basic_unit
from
...utils
import
NoContextError
from
nni.common.serializer
import
Translatable
from
nni.retiarii.serializer
import
basic_unit
from
nni.retiarii.utils
import
NoContextError
from
.utils
import
generate_new_label
,
get_fixed_value
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment