Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
443ba8c1
Unverified
Commit
443ba8c1
authored
Dec 06, 2021
by
Yuge Zhang
Committed by
GitHub
Dec 06, 2021
Browse files
Serialization infrastructure V2 (#4337)
parent
896c516f
Changes
40
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
440 additions
and
536 deletions
+440
-536
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+11
-9
nni/retiarii/nn/pytorch/nn.py
nni/retiarii/nn/pytorch/nn.py
+1
-2
nni/retiarii/nn/pytorch/utils.py
nni/retiarii/nn/pytorch/utils.py
+2
-2
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+90
-125
nni/retiarii/utils.py
nni/retiarii/utils.py
+45
-0
nni/runtime/msg_dispatcher.py
nni/runtime/msg_dispatcher.py
+9
-9
nni/runtime/msg_dispatcher_base.py
nni/runtime/msg_dispatcher_base.py
+3
-3
nni/runtime/platform/local.py
nni/runtime/platform/local.py
+3
-4
nni/runtime/platform/standalone.py
nni/runtime/platform/standalone.py
+3
-2
nni/runtime/platform/test.py
nni/runtime/platform/test.py
+8
-3
nni/tools/nnictl/config_utils.py
nni/tools/nnictl/config_utils.py
+4
-4
nni/trial.py
nni/trial.py
+5
-5
nni/utils.py
nni/utils.py
+0
-4
test/ut/retiarii/inject_nn.py
test/ut/retiarii/inject_nn.py
+11
-253
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+4
-5
test/ut/retiarii/test_lightning_trainer.py
test/ut/retiarii/test_lightning_trainer.py
+1
-1
test/ut/retiarii/test_namespace.py
test/ut/retiarii/test_namespace.py
+88
-0
test/ut/retiarii/test_serializer.py
test/ut/retiarii/test_serializer.py
+0
-96
test/ut/sdk/imported/model.py
test/ut/sdk/imported/model.py
+0
-0
test/ut/sdk/test_serializer.py
test/ut/sdk/test_serializer.py
+152
-9
No files found.
nni/retiarii/nn/pytorch/mutator.py
View file @
443ba8c1
...
...
@@ -6,11 +6,13 @@ from typing import Any, List, Optional, Tuple
import
torch.nn
as
nn
from
...mutator
import
Mutator
from
...graph
import
Cell
,
Graph
,
Model
,
ModelStatus
,
Node
from
nni.retiarii.graph
import
Cell
,
Graph
,
Model
,
ModelStatus
,
Node
from
nni.retiarii.mutator
import
Mutator
from
nni.retiarii.serializer
import
is_basic_unit
from
nni.retiarii.utils
import
uid
from
.api
import
LayerChoice
,
InputChoice
,
ValueChoice
,
Placeholder
from
.component
import
Repeat
,
NasBench101Cell
,
NasBench101Mutator
from
...utils
import
uid
class
LayerChoiceMutator
(
Mutator
):
...
...
@@ -221,17 +223,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
graph
=
Graph
(
model
,
uid
(),
'_model'
,
_internal
=
True
).
_register
()
model
.
python_class
=
pytorch_model
.
__class__
if
len
(
inspect
.
signature
(
model
.
python_class
.
__init__
).
parameters
)
>
1
:
if
not
has
attr
(
pytorch_model
,
'_
i
ni
t_parameters'
):
raise
ValueError
(
'Please annotate the model with @
serialize
decorator in python execution mode '
if
not
get
attr
(
pytorch_model
,
'_
n
ni
_model_wrapper'
,
False
):
raise
ValueError
(
'Please annotate the model with @
model_wrapper
decorator in python execution mode '
'if your model has init parameters.'
)
model
.
python_init_params
=
pytorch_model
.
_init_parameter
s
model
.
python_init_params
=
pytorch_model
.
trace_kwarg
s
else
:
model
.
python_init_params
=
{}
for
name
,
module
in
pytorch_model
.
named_modules
():
# tricky case: value choice that serves as parameters are stored in
_init_parameter
s
if
hasattr
(
module
,
'_init_parameters'
):
for
key
,
value
in
module
.
_init_parameter
s
.
items
():
# tricky case: value choice that serves as parameters are stored in
traced argument
s
if
is_basic_unit
(
module
):
for
key
,
value
in
module
.
trace_kwarg
s
.
items
():
if
isinstance
(
value
,
ValueChoice
):
node
=
graph
.
add_node
(
name
+
'.init.'
+
key
,
'ValueChoice'
,
{
'candidates'
:
value
.
candidates
})
node
.
label
=
value
.
label
...
...
nni/retiarii/nn/pytorch/nn.py
View file @
443ba8c1
...
...
@@ -5,7 +5,6 @@ import torch
import
torch.nn
as
nn
from
...serializer
import
basic_unit
from
...serializer
import
transparent_serialize
from
...utils
import
version_larger_equal
# NOTE: support pytorch version >= 1.5.0
...
...
@@ -42,7 +41,7 @@ if version_larger_equal(torch.__version__, '1.7.0'):
Module
=
nn
.
Module
Sequential
=
nn
.
Sequential
ModuleList
=
transparent_serialize
(
nn
.
ModuleList
)
ModuleList
=
basic_unit
(
nn
.
ModuleList
,
basic_unit_tag
=
False
)
Identity
=
basic_unit
(
nn
.
Identity
)
Linear
=
basic_unit
(
nn
.
Linear
)
...
...
nni/retiarii/nn/pytorch/utils.py
View file @
443ba8c1
from
typing
import
Any
,
Optional
,
Tuple
from
..
.utils
import
uid
,
get_current_context
from
nni.retiarii
.utils
import
ModelNamespace
,
get_current_context
def
generate_new_label
(
label
:
Optional
[
str
]):
if
label
is
None
:
return
'_mutation_'
+
str
(
uid
(
'mutation'
)
)
return
ModelNamespace
.
next_label
(
)
return
label
...
...
nni/retiarii/serializer.py
View file @
443ba8c1
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
abc
import
functools
import
inspect
import
type
s
from
typing
import
Any
import
warning
s
from
typing
import
Any
,
TypeVar
,
Union
import
json_tricks
from
nni.common.serializer
import
Traceable
,
is_traceable
,
trace
,
_copy_class_wrapper_attributes
from
.utils
import
ModelNamespace
from
.utils
import
get_importable_name
,
get_module_name
,
import_
,
reset_uid
__all__
=
[
'get_init_parameters_or_fail'
,
'serialize'
,
'serialize_cls'
,
'basic_unit'
,
'model_wrapper'
,
'is_basic_unit'
,
'is_model_wrapped'
]
T
=
TypeVar
(
'T'
)
def
get_init_parameters_or_fail
(
obj
,
silently
=
False
):
if
hasattr
(
obj
,
'_init_parameters'
):
return
obj
.
_init_parameters
elif
silently
:
return
None
else
:
raise
ValueError
(
f
'Object
{
obj
}
needs to be serializable but `_init_parameters` is not available. '
def
get_init_parameters_or_fail
(
obj
:
Any
):
if
is_traceable
(
obj
):
return
obj
.
trace_kwargs
raise
ValueError
(
f
'Object
{
obj
}
needs to be serializable but `trace_kwargs` is not available. '
'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
'If it is a customized module, please to decorate it with @basic_unit. '
'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
'try to use serialize or @serialize_cls.'
)
### This is a patch of json-tricks to make it more useful to us ###
'try to use @nni.trace.'
)
def
_serialize_class_instance_encode
(
obj
,
primitives
=
False
):
assert
not
primitives
,
'Encoding with primitives is not supported.'
try
:
# FIXME: raise error
if
hasattr
(
obj
,
'__class__'
):
return
{
'__type__'
:
get_importable_name
(
obj
.
__class__
),
'arguments'
:
get_init_parameters_or_fail
(
obj
)
}
except
ValueError
:
pass
return
obj
def
_serialize_class_instance_decode
(
obj
):
if
isinstance
(
obj
,
dict
)
and
'__type__'
in
obj
and
'arguments'
in
obj
:
return
import_
(
obj
[
'__type__'
])(
**
obj
[
'arguments'
])
return
obj
def
serialize
(
cls
,
*
args
,
**
kwargs
):
"""
To create an serializable instance inline without decorator. For example,
.. code-block:: python
def
_type_encode
(
obj
,
primitives
=
False
):
assert
not
primitives
,
'Encoding with primitives is not supported.'
if
isinstance
(
obj
,
type
):
return
{
'__typename__'
:
get_importable_name
(
obj
,
relocate_module
=
True
)}
if
isinstance
(
obj
,
(
types
.
FunctionType
,
types
.
BuiltinFunctionType
)):
# This is not reliable for cases like closure, `open`, or objects that is callable but not intended to be serialized.
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return
{
'__typename__'
:
get_importable_name
(
obj
,
relocate_module
=
True
)}
return
obj
self.op = serialize(MyCustomOp, hidden_units=128)
"""
warnings
.
warn
(
'nni.retiarii.serialize is deprecated and will be removed in future release. '
+
'Try to use nni.trace, e.g., nni.trace(torch.optim.Adam)(learning_rate=1e-4) instead.'
,
category
=
DeprecationWarning
)
return
trace
(
cls
)(
*
args
,
**
kwargs
)
def
_type_decode
(
obj
):
if
isinstance
(
obj
,
dict
)
and
'__typename__'
in
obj
:
return
import_
(
obj
[
'__typename__'
])
return
obj
def
serialize_cls
(
cls
):
"""
To create an serializable class.
"""
warnings
.
warn
(
'nni.retiarii.serialize is deprecated and will be removed in future release. '
+
'Try to use nni.trace instead.'
,
category
=
DeprecationWarning
)
return
trace
(
cls
)
json_loads
=
functools
.
partial
(
json_tricks
.
loads
,
extra_obj_pairs_hooks
=
[
_serialize_class_instance_decode
,
_type_decode
])
json_dumps
=
functools
.
partial
(
json_tricks
.
dumps
,
extra_obj_encoders
=
[
_serialize_class_instance_encode
,
_type_encode
])
json_load
=
functools
.
partial
(
json_tricks
.
load
,
extra_obj_pairs_hooks
=
[
_serialize_class_instance_decode
,
_type_decode
])
json_dump
=
functools
.
partial
(
json_tricks
.
dump
,
extra_obj_encoders
=
[
_serialize_class_instance_encode
,
_type_encode
])
def
basic_unit
(
cls
:
T
,
basic_unit_tag
:
bool
=
True
)
->
Union
[
T
,
Traceable
]:
"""
To wrap a module as a basic unit, is to make it a primitive and stop the engine from digging deeper into it.
### End of json-tricks patch ###
``basic_unit_tag`` is true by default. If set to false, it will not be explicitly mark as a basic unit, and
graph parser will continue to parse. Currently, this is to handle a special case in ``nn.Sequential``.
.. code-block:: python
class
Translatable
(
abc
.
ABC
):
"""
Inherit this class and implement ``translate`` when the inner class needs a different
parameter from the wrapper class in its init function.
@basic_unit
class PrimitiveOp(nn.Module):
...
"""
_check_wrapped
(
cls
)
@
abc
.
abstractmethod
def
_translate
(
self
)
->
Any
:
pass
import
torch.nn
as
nn
assert
issubclass
(
cls
,
nn
.
Module
),
'When using @basic_unit, the class must be a subclass of nn.Module.'
cls
=
trace
(
cls
)
cls
.
_nni_basic_unit
=
basic_unit_tag
def
_create_wrapper_cls
(
cls
,
store_init_parameters
=
True
,
reset_mutation_uid
=
False
,
stop_parsing
=
True
):
class
wrapper
(
cls
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
_stop_parsing
=
stop_parsing
if
reset_mutation_uid
:
reset_uid
(
'mutation'
)
if
store_init_parameters
:
argname_list
=
list
(
inspect
.
signature
(
cls
.
__init__
).
parameters
.
keys
())[
1
:]
full_args
=
{}
full_args
.
update
(
kwargs
)
assert
len
(
args
)
<=
len
(
argname_list
),
f
'Length of
{
args
}
is greater than length of
{
argname_list
}
.'
for
argname
,
value
in
zip
(
argname_list
,
args
):
full_args
[
argname
]
=
value
# translate parameters
args
=
list
(
args
)
for
i
,
value
in
enumerate
(
args
):
if
isinstance
(
value
,
Translatable
):
args
[
i
]
=
value
.
_translate
()
for
i
,
value
in
kwargs
.
items
():
if
isinstance
(
value
,
Translatable
):
kwargs
[
i
]
=
value
.
_translate
()
self
.
_init_parameters
=
full_args
else
:
self
.
_init_parameters
=
{}
# HACK: for torch script
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import
torch
cls
.
_get_nni_attr
=
torch
.
jit
.
ignore
(
cls
.
_get_nni_attr
)
cls
.
trace_symbol
=
torch
.
jit
.
unused
(
cls
.
trace_symbol
)
cls
.
trace_args
=
torch
.
jit
.
unused
(
cls
.
trace_args
)
cls
.
trace_kwargs
=
torch
.
jit
.
unused
(
cls
.
trace_kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
return
cls
wrapper
.
__module__
=
get_module_name
(
cls
)
wrapper
.
__name__
=
cls
.
__name__
wrapper
.
__qualname__
=
cls
.
__qualname__
wrapper
.
__init__
.
__doc__
=
cls
.
__init__
.
__doc__
return
wrapper
def
model_wrapper
(
cls
:
T
)
->
Union
[
T
,
Traceable
]:
"""
Wrap the model if you are using pure-python execution engine. For example
.. code-block:: python
def
serialize_cls
(
cls
):
"""
To create an serializable class.
"""
return
_create_wrapper_cls
(
cls
)
@model_wrapper
class MyModel(nn.Module):
...
The wrapper serves two purposes:
def
transparent_serialize
(
cls
):
"""
Wrap a module but does not record parameters. For internal use only
.
1. Capture the init parameters of python class so that it can be re-instantiated in another process.
2. Reset uid in ``mutation`` namespace so that each model counts from zero.
Can be useful in unittest and other multi-model scenarios
.
"""
return
_create_wrapper_cls
(
cls
,
store_init_parameters
=
False
)
_check_wrapped
(
cls
)
import
torch.nn
as
nn
assert
issubclass
(
cls
,
nn
.
Module
)
def
serialize
(
cls
,
*
args
,
**
kwargs
):
"""
To create an serializable instance inline without decorator. For example,
wrapper
=
trace
(
cls
)
.. code-block:: python
class
reset_wrapper
(
wrapper
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
with
ModelNamespace
():
super
().
__init__
(
*
args
,
**
kwargs
)
self.op = serialize(MyCustomOp, hidden_units=128)
"""
return
serialize_cls
(
cls
)(
*
args
,
**
kwargs
)
_copy_class_wrapper_attributes
(
wrapper
,
reset_wrapper
)
reset_wrapper
.
__wrapped__
=
wrapper
.
__wrapped__
reset_wrapper
.
_nni_model_wrapper
=
True
return
reset_wrapper
def
basic_unit
(
cls
):
"""
To wrap a module as a basic unit, to stop it from parsing and make it mutate-able.
"""
import
torch.nn
as
nn
assert
issubclass
(
cls
,
nn
.
Module
),
'When using @basic_unit, the class must be a subclass of nn.Module.'
return
serialize_cls
(
cls
)
def
is_basic_unit
(
cls_or_instance
)
->
bool
:
if
not
inspect
.
isclass
(
cls_or_instance
):
cls_or_instance
=
cls_or_instance
.
__class__
return
getattr
(
cls_or_instance
,
'_nni_basic_unit'
,
False
)
def
model_wrapper
(
cls
):
"""
Wrap the model if you are using pure-python execution engine.
def
is_model_wrapped
(
cls_or_instance
)
->
bool
:
if
not
inspect
.
isclass
(
cls_or_instance
):
cls_or_instance
=
cls_or_instance
.
__class__
return
getattr
(
cls_or_instance
,
'_nni_model_wrapper'
,
False
)
The wrapper serves two purposes:
1. Capture the init parameters of python class so that it can be re-instantiated in another process.
2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios.
"""
return
_create_wrapper_cls
(
cls
,
reset_mutation_uid
=
True
,
stop_parsing
=
False
)
def
_check_wrapped
(
cls
:
T
)
->
bool
:
if
getattr
(
cls
,
'_traced'
,
False
)
or
getattr
(
cls
,
'_nni_model_wrapper'
,
False
):
raise
TypeError
(
f
'
{
cls
}
is already wrapped with trace wrapper (basic_unit / model_wrapper / trace). Cannot wrap again.'
)
nni/retiarii/utils.py
View file @
443ba8c1
...
...
@@ -25,6 +25,8 @@ def version_larger_equal(a: str, b: str) -> bool:
_last_uid
=
defaultdict
(
int
)
_DEFAULT_MODEL_NAMESPACE
=
'model'
def
uid
(
namespace
:
str
=
'default'
)
->
int
:
_last_uid
[
namespace
]
+=
1
...
...
@@ -77,6 +79,8 @@ class ContextStack:
Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
get the corresponding value in the namespace.
Note that this is not multi-processing safe. Also, the values will get cleared for a new process.
"""
_stack
:
Dict
[
str
,
List
[
Any
]]
=
defaultdict
(
list
)
...
...
@@ -107,5 +111,46 @@ class ContextStack:
return
cls
.
_stack
[
key
][
-
1
]
class
ModelNamespace
:
"""
To create an individual namespace for models to enable automatic numbering.
"""
def
__init__
(
self
,
key
:
str
=
_DEFAULT_MODEL_NAMESPACE
):
# for example, key: "model_wrapper"
self
.
key
=
key
def
__enter__
(
self
):
# For example, currently the top of stack is [1, 2, 2], and [1, 2, 2, 3] is used,
# the next thing up is [1, 2, 2, 4].
# `reset_uid` to count from zero for "model_wrapper_1_2_2_4"
try
:
current_context
=
ContextStack
.
top
(
self
.
key
)
next_uid
=
uid
(
self
.
_simple_name
(
self
.
key
,
current_context
))
ContextStack
.
push
(
self
.
key
,
current_context
+
[
next_uid
])
reset_uid
(
self
.
_simple_name
(
self
.
key
,
current_context
+
[
next_uid
]))
except
NoContextError
:
ContextStack
.
push
(
self
.
key
,
[])
reset_uid
(
self
.
_simple_name
(
self
.
key
,
[]))
def
__exit__
(
self
,
*
args
,
**
kwargs
):
ContextStack
.
pop
(
self
.
key
)
@
staticmethod
def
next_label
(
key
:
str
=
_DEFAULT_MODEL_NAMESPACE
)
->
str
:
try
:
current_context
=
ContextStack
.
top
(
key
)
except
NoContextError
:
# fallback to use "default" namespace
return
ModelNamespace
.
_simple_name
(
'default'
,
[
uid
()])
next_uid
=
uid
(
ModelNamespace
.
_simple_name
(
key
,
current_context
))
return
ModelNamespace
.
_simple_name
(
key
,
current_context
+
[
next_uid
])
@
staticmethod
def
_simple_name
(
key
:
str
,
lst
:
List
[
Any
])
->
str
:
return
key
+
''
.
join
([
'_'
+
str
(
k
)
for
k
in
lst
])
def
get_current_context
(
key
:
str
)
->
Any
:
return
ContextStack
.
top
(
key
)
nni/runtime/msg_dispatcher.py
View file @
443ba8c1
...
...
@@ -3,7 +3,6 @@
import
logging
from
collections
import
defaultdict
import
json_tricks
from
nni
import
NoMoreTrialError
from
nni.assessor
import
AssessResult
...
...
@@ -12,7 +11,8 @@ from .common import multi_thread_enabled, multi_phase_enabled
from
.env_vars
import
dispatcher_env_vars
from
.msg_dispatcher_base
import
MsgDispatcherBase
from
.protocol
import
CommandType
,
send
from
..utils
import
MetricType
,
to_json
from
..common.serializer
import
dump
,
load
from
..utils
import
MetricType
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -63,7 +63,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p
ret
[
'parameter_index'
]
=
parameter_index
else
:
ret
[
'parameter_index'
]
=
0
return
to_json
(
ret
)
return
dump
(
ret
)
class
MsgDispatcher
(
MsgDispatcherBase
):
...
...
@@ -115,8 +115,8 @@ class MsgDispatcher(MsgDispatcherBase):
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
for
entry
in
data
:
entry
[
'value'
]
=
entry
[
'value'
]
if
type
(
entry
[
'value'
])
is
str
else
json_tricks
.
dump
s
(
entry
[
'value'
])
entry
[
'value'
]
=
json_tricks
.
load
s
(
entry
[
'value'
])
entry
[
'value'
]
=
entry
[
'value'
]
if
type
(
entry
[
'value'
])
is
str
else
dump
(
entry
[
'value'
])
entry
[
'value'
]
=
load
(
entry
[
'value'
])
self
.
tuner
.
import_data
(
data
)
def
handle_add_customized_trial
(
self
,
data
):
...
...
@@ -133,7 +133,7 @@ class MsgDispatcher(MsgDispatcherBase):
"""
# metrics value is dumped as json string in trial, so we need to decode it here
if
'value'
in
data
:
data
[
'value'
]
=
json_tricks
.
load
s
(
data
[
'value'
])
data
[
'value'
]
=
load
(
data
[
'value'
])
if
data
[
'type'
]
==
MetricType
.
FINAL
:
self
.
_handle_final_metric_data
(
data
)
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
...
...
@@ -167,7 +167,7 @@ class MsgDispatcher(MsgDispatcherBase):
if
self
.
assessor
is
not
None
:
self
.
assessor
.
trial_end
(
trial_job_id
,
data
[
'event'
]
==
'SUCCEEDED'
)
if
self
.
tuner
is
not
None
:
self
.
tuner
.
trial_end
(
json_tricks
.
load
s
(
data
[
'hyper_params'
])[
'parameter_id'
],
data
[
'event'
]
==
'SUCCEEDED'
)
self
.
tuner
.
trial_end
(
load
(
data
[
'hyper_params'
])[
'parameter_id'
],
data
[
'event'
]
==
'SUCCEEDED'
)
def
_handle_final_metric_data
(
self
,
data
):
"""Call tuner to process final results
...
...
@@ -221,7 +221,7 @@ class MsgDispatcher(MsgDispatcherBase):
if
result
is
AssessResult
.
Bad
:
_logger
.
debug
(
'BAD, kill %s'
,
trial_job_id
)
send
(
CommandType
.
KillTrialJob
,
json_tricks
.
dump
s
(
trial_job_id
))
send
(
CommandType
.
KillTrialJob
,
dump
(
trial_job_id
))
# notify tuner
_logger
.
debug
(
'env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]'
,
dispatcher_env_vars
.
NNI_INCLUDE_INTERMEDIATE_RESULTS
)
...
...
@@ -239,5 +239,5 @@ class MsgDispatcher(MsgDispatcherBase):
if
multi_thread_enabled
():
self
.
_handle_final_metric_data
(
data
)
else
:
data
[
'value'
]
=
to_json
(
data
[
'value'
])
data
[
'value'
]
=
dump
(
data
[
'value'
])
self
.
enqueue_command
(
CommandType
.
ReportMetricData
,
data
)
nni/runtime/msg_dispatcher_base.py
View file @
443ba8c1
...
...
@@ -5,10 +5,10 @@ import threading
import
logging
from
multiprocessing.dummy
import
Pool
as
ThreadPool
from
queue
import
Queue
,
Empty
import
json_tricks
from
.common
import
multi_thread_enabled
from
.env_vars
import
dispatcher_env_vars
from
..common
import
load
from
..recoverable
import
Recoverable
from
.protocol
import
CommandType
,
receive
...
...
@@ -50,7 +50,7 @@ class MsgDispatcherBase(Recoverable):
while
not
self
.
stopping
:
command
,
data
=
receive
()
if
data
:
data
=
json_tricks
.
load
s
(
data
)
data
=
load
(
data
)
if
command
is
None
or
command
is
CommandType
.
Terminate
:
break
...
...
@@ -162,7 +162,7 @@ class MsgDispatcherBase(Recoverable):
def
handle_request_trial_jobs
(
self
,
data
):
"""The message dispatcher is demanded to generate ``data`` trial jobs.
These trial jobs should be sent via ``send(CommandType.NewTrialJob,
json_tricks
.dump
s
(parameter))``,
These trial jobs should be sent via ``send(CommandType.NewTrialJob,
nni
.dump(parameter))``,
where ``parameter`` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
Semantically, message dispatcher should do this ``send`` exactly ``data`` times.
...
...
nni/runtime/platform/local.py
View file @
443ba8c1
...
...
@@ -3,11 +3,10 @@
import
os
import
sys
import
json
import
time
import
subprocess
from
nni.
utils
import
to_json
from
nni.
common
import
dump
,
load
from
..env_vars
import
trial_env_vars
_sysdir
=
trial_env_vars
.
NNI_SYS_DIR
...
...
@@ -27,7 +26,7 @@ _multiphase = trial_env_vars.MULTI_PHASE
_param_index
=
0
def
request_next_parameter
():
metric
=
to_json
({
metric
=
dump
({
'trial_job_id'
:
trial_env_vars
.
NNI_TRIAL_JOB_ID
,
'type'
:
'REQUEST_PARAMETER'
,
'sequence'
:
0
,
...
...
@@ -54,7 +53,7 @@ def get_next_parameter():
while
not
(
os
.
path
.
isfile
(
params_filepath
)
and
os
.
path
.
getsize
(
params_filepath
)
>
0
):
time
.
sleep
(
3
)
params_file
=
open
(
params_filepath
,
'r'
)
params
=
json
.
load
(
params_file
)
params
=
load
(
fp
=
params_file
)
_param_index
+=
1
return
params
...
...
nni/runtime/platform/standalone.py
View file @
443ba8c1
...
...
@@ -5,7 +5,8 @@ import logging
import
warnings
import
colorama
import
json_tricks
from
nni.common
import
load
__all__
=
[
'get_next_parameter'
,
...
...
@@ -44,7 +45,7 @@ def get_sequence_id():
return
0
def
send_metric
(
string
):
metric
=
json_tricks
.
load
s
(
string
)
metric
=
load
(
string
)
if
metric
[
'type'
]
==
'FINAL'
:
_logger
.
info
(
'Final result: %s'
,
metric
[
'value'
])
elif
metric
[
'type'
]
==
'PERIODICAL'
:
...
...
nni/runtime/platform/test.py
View file @
443ba8c1
...
...
@@ -4,7 +4,7 @@
# pylint: skip-file
import
copy
import
json_tricks
from
nni.common
import
load
_params
=
None
...
...
@@ -14,15 +14,19 @@ _last_metric = None
def
get_next_parameter
():
return
_params
def
get_experiment_id
():
return
'fakeidex'
def
get_trial_id
():
return
'fakeidtr'
def
get_sequence_id
():
return
0
def
send_metric
(
string
):
global
_last_metric
_last_metric
=
string
...
...
@@ -32,8 +36,9 @@ def init_params(params):
global
_params
_params
=
copy
.
deepcopy
(
params
)
def
get_last_metric
():
metrics
=
json_tricks
.
load
s
(
_last_metric
)
metrics
[
'value'
]
=
json_tricks
.
load
s
(
metrics
[
'value'
])
metrics
=
load
(
_last_metric
)
metrics
[
'value'
]
=
load
(
metrics
[
'value'
])
return
metrics
nni/tools/nnictl/config_utils.py
View file @
443ba8c1
...
...
@@ -3,7 +3,7 @@
import
os
import
sqlite3
import
json_tricks
import
nni
from
.constants
import
NNI_HOME_DIR
from
.common_utils
import
get_file_lock
...
...
@@ -95,7 +95,7 @@ class Config:
'''refresh to get latest config'''
sql
=
'select params from ExperimentProfile where id=? order by revision DESC'
args
=
(
self
.
experiment_id
,)
self
.
config
=
config_v0_to_v1
(
json_tricks
.
load
s
(
self
.
conn
.
cursor
().
execute
(
sql
,
args
).
fetchone
()[
0
]))
self
.
config
=
config_v0_to_v1
(
nni
.
load
(
self
.
conn
.
cursor
().
execute
(
sql
,
args
).
fetchone
()[
0
]))
def
get_config
(
self
):
'''get a value according to key'''
...
...
@@ -159,7 +159,7 @@ class Experiments:
'''save config to local file'''
try
:
with
open
(
self
.
experiment_file
,
'w'
)
as
file
:
json_tricks
.
dump
(
self
.
experiments
,
file
,
indent
=
4
)
nni
.
dump
(
self
.
experiments
,
file
,
indent
=
4
)
except
IOError
as
error
:
print
(
'Error:'
,
error
)
return
''
...
...
@@ -169,7 +169,7 @@ class Experiments:
if
os
.
path
.
exists
(
self
.
experiment_file
):
try
:
with
open
(
self
.
experiment_file
,
'r'
)
as
file
:
return
json_tricks
.
load
(
file
)
return
nni
.
load
(
fp
=
file
)
except
ValueError
:
return
{}
return
{}
nni/trial.py
View file @
443ba8c1
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.
utils
import
to_json
from
.
common.serializer
import
dump
from
.runtime.env_vars
import
trial_env_vars
from
.runtime
import
platform
...
...
@@ -124,12 +124,12 @@ def report_intermediate_result(metric):
global
_intermediate_seq
assert
_params
or
trial_env_vars
.
NNI_PLATFORM
is
None
,
\
'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric
=
to_json
({
metric
=
dump
({
'parameter_id'
:
_params
[
'parameter_id'
]
if
_params
else
None
,
'trial_job_id'
:
trial_env_vars
.
NNI_TRIAL_JOB_ID
,
'type'
:
'PERIODICAL'
,
'sequence'
:
_intermediate_seq
,
'value'
:
to_json
(
metric
)
'value'
:
dump
(
metric
)
})
_intermediate_seq
+=
1
platform
.
send_metric
(
metric
)
...
...
@@ -146,11 +146,11 @@ def report_final_result(metric):
"""
assert
_params
or
trial_env_vars
.
NNI_PLATFORM
is
None
,
\
'nni.get_next_parameter() needs to be called before report_final_result'
metric
=
to_json
({
metric
=
dump
({
'parameter_id'
:
_params
[
'parameter_id'
]
if
_params
else
None
,
'trial_job_id'
:
trial_env_vars
.
NNI_TRIAL_JOB_ID
,
'type'
:
'FINAL'
,
'sequence'
:
0
,
'value'
:
to_json
(
metric
)
'value'
:
dump
(
metric
)
})
platform
.
send_metric
(
metric
)
nni/utils.py
View file @
443ba8c1
...
...
@@ -2,17 +2,13 @@
# Licensed under the MIT license.
import
copy
import
functools
from
enum
import
Enum
,
unique
from
pathlib
import
Path
import
json_tricks
from
schema
import
And
from
.
import
parameter_expressions
to_json
=
functools
.
partial
(
json_tricks
.
dumps
,
allow_nan
=
True
)
@
unique
class
OptimizeMode
(
Enum
):
"""Optimize Mode class
...
...
test/ut/retiarii/inject_nn.py
View file @
443ba8c1
import
inspect
import
logging
import
torch
import
torch.nn
as
nn
from
nni.retiarii
.utils
import
version_larger_equal
from
nni.retiarii
import
basic_unit
_logger
=
logging
.
getLogger
(
__name__
)
_trace_module_names
=
[
module_name
for
module_name
in
dir
(
nn
)
if
module_name
not
in
[
'Module'
,
'ModuleList'
,
'ModuleDict'
,
'Sequential'
]
and
inspect
.
isclass
(
getattr
(
nn
,
module_name
))
and
issubclass
(
getattr
(
nn
,
module_name
),
nn
.
Module
)
]
def
wrap_module
(
original_class
):
orig_init
=
original_class
.
__init__
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
# Make copy of original __init__, so we can call it without recursion
original_class
.
bak_init_for_inject
=
orig_init
def
__init__
(
self
,
*
args
,
**
kws
):
full_args
=
{}
full_args
.
update
(
kws
)
for
i
,
arg
in
enumerate
(
args
):
full_args
[
argname_list
[
i
]]
=
arg
self
.
_init_parameters
=
full_args
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
return
original_class
def
unwrap_module
(
wrapped_class
):
if
hasattr
(
wrapped_class
,
'bak_init_for_inject'
):
wrapped_class
.
__init__
=
wrapped_class
.
bak_init_for_inject
delattr
(
wrapped_class
,
'bak_init_for_inject'
)
return
None
def
remove_inject_pytorch_nn
():
Identity
=
unwrap_module
(
nn
.
Identity
)
Linear
=
unwrap_module
(
nn
.
Linear
)
Conv1d
=
unwrap_module
(
nn
.
Conv1d
)
Conv2d
=
unwrap_module
(
nn
.
Conv2d
)
Conv3d
=
unwrap_module
(
nn
.
Conv3d
)
ConvTranspose1d
=
unwrap_module
(
nn
.
ConvTranspose1d
)
ConvTranspose2d
=
unwrap_module
(
nn
.
ConvTranspose2d
)
ConvTranspose3d
=
unwrap_module
(
nn
.
ConvTranspose3d
)
Threshold
=
unwrap_module
(
nn
.
Threshold
)
ReLU
=
unwrap_module
(
nn
.
ReLU
)
Hardtanh
=
unwrap_module
(
nn
.
Hardtanh
)
ReLU6
=
unwrap_module
(
nn
.
ReLU6
)
Sigmoid
=
unwrap_module
(
nn
.
Sigmoid
)
Tanh
=
unwrap_module
(
nn
.
Tanh
)
Softmax
=
unwrap_module
(
nn
.
Softmax
)
Softmax2d
=
unwrap_module
(
nn
.
Softmax2d
)
LogSoftmax
=
unwrap_module
(
nn
.
LogSoftmax
)
ELU
=
unwrap_module
(
nn
.
ELU
)
SELU
=
unwrap_module
(
nn
.
SELU
)
CELU
=
unwrap_module
(
nn
.
CELU
)
GLU
=
unwrap_module
(
nn
.
GLU
)
GELU
=
unwrap_module
(
nn
.
GELU
)
Hardshrink
=
unwrap_module
(
nn
.
Hardshrink
)
LeakyReLU
=
unwrap_module
(
nn
.
LeakyReLU
)
LogSigmoid
=
unwrap_module
(
nn
.
LogSigmoid
)
Softplus
=
unwrap_module
(
nn
.
Softplus
)
Softshrink
=
unwrap_module
(
nn
.
Softshrink
)
MultiheadAttention
=
unwrap_module
(
nn
.
MultiheadAttention
)
PReLU
=
unwrap_module
(
nn
.
PReLU
)
Softsign
=
unwrap_module
(
nn
.
Softsign
)
Softmin
=
unwrap_module
(
nn
.
Softmin
)
Tanhshrink
=
unwrap_module
(
nn
.
Tanhshrink
)
RReLU
=
unwrap_module
(
nn
.
RReLU
)
AvgPool1d
=
unwrap_module
(
nn
.
AvgPool1d
)
AvgPool2d
=
unwrap_module
(
nn
.
AvgPool2d
)
AvgPool3d
=
unwrap_module
(
nn
.
AvgPool3d
)
MaxPool1d
=
unwrap_module
(
nn
.
MaxPool1d
)
MaxPool2d
=
unwrap_module
(
nn
.
MaxPool2d
)
MaxPool3d
=
unwrap_module
(
nn
.
MaxPool3d
)
MaxUnpool1d
=
unwrap_module
(
nn
.
MaxUnpool1d
)
MaxUnpool2d
=
unwrap_module
(
nn
.
MaxUnpool2d
)
MaxUnpool3d
=
unwrap_module
(
nn
.
MaxUnpool3d
)
FractionalMaxPool2d
=
unwrap_module
(
nn
.
FractionalMaxPool2d
)
FractionalMaxPool3d
=
unwrap_module
(
nn
.
FractionalMaxPool3d
)
LPPool1d
=
unwrap_module
(
nn
.
LPPool1d
)
LPPool2d
=
unwrap_module
(
nn
.
LPPool2d
)
LocalResponseNorm
=
unwrap_module
(
nn
.
LocalResponseNorm
)
BatchNorm1d
=
unwrap_module
(
nn
.
BatchNorm1d
)
BatchNorm2d
=
unwrap_module
(
nn
.
BatchNorm2d
)
BatchNorm3d
=
unwrap_module
(
nn
.
BatchNorm3d
)
InstanceNorm1d
=
unwrap_module
(
nn
.
InstanceNorm1d
)
InstanceNorm2d
=
unwrap_module
(
nn
.
InstanceNorm2d
)
InstanceNorm3d
=
unwrap_module
(
nn
.
InstanceNorm3d
)
LayerNorm
=
unwrap_module
(
nn
.
LayerNorm
)
GroupNorm
=
unwrap_module
(
nn
.
GroupNorm
)
SyncBatchNorm
=
unwrap_module
(
nn
.
SyncBatchNorm
)
Dropout
=
unwrap_module
(
nn
.
Dropout
)
Dropout2d
=
unwrap_module
(
nn
.
Dropout2d
)
Dropout3d
=
unwrap_module
(
nn
.
Dropout3d
)
AlphaDropout
=
unwrap_module
(
nn
.
AlphaDropout
)
FeatureAlphaDropout
=
unwrap_module
(
nn
.
FeatureAlphaDropout
)
ReflectionPad1d
=
unwrap_module
(
nn
.
ReflectionPad1d
)
ReflectionPad2d
=
unwrap_module
(
nn
.
ReflectionPad2d
)
ReplicationPad2d
=
unwrap_module
(
nn
.
ReplicationPad2d
)
ReplicationPad1d
=
unwrap_module
(
nn
.
ReplicationPad1d
)
ReplicationPad3d
=
unwrap_module
(
nn
.
ReplicationPad3d
)
CrossMapLRN2d
=
unwrap_module
(
nn
.
CrossMapLRN2d
)
Embedding
=
unwrap_module
(
nn
.
Embedding
)
EmbeddingBag
=
unwrap_module
(
nn
.
EmbeddingBag
)
RNNBase
=
unwrap_module
(
nn
.
RNNBase
)
RNN
=
unwrap_module
(
nn
.
RNN
)
LSTM
=
unwrap_module
(
nn
.
LSTM
)
GRU
=
unwrap_module
(
nn
.
GRU
)
RNNCellBase
=
unwrap_module
(
nn
.
RNNCellBase
)
RNNCell
=
unwrap_module
(
nn
.
RNNCell
)
LSTMCell
=
unwrap_module
(
nn
.
LSTMCell
)
GRUCell
=
unwrap_module
(
nn
.
GRUCell
)
PixelShuffle
=
unwrap_module
(
nn
.
PixelShuffle
)
Upsample
=
unwrap_module
(
nn
.
Upsample
)
UpsamplingNearest2d
=
unwrap_module
(
nn
.
UpsamplingNearest2d
)
UpsamplingBilinear2d
=
unwrap_module
(
nn
.
UpsamplingBilinear2d
)
PairwiseDistance
=
unwrap_module
(
nn
.
PairwiseDistance
)
AdaptiveMaxPool1d
=
unwrap_module
(
nn
.
AdaptiveMaxPool1d
)
AdaptiveMaxPool2d
=
unwrap_module
(
nn
.
AdaptiveMaxPool2d
)
AdaptiveMaxPool3d
=
unwrap_module
(
nn
.
AdaptiveMaxPool3d
)
AdaptiveAvgPool1d
=
unwrap_module
(
nn
.
AdaptiveAvgPool1d
)
AdaptiveAvgPool2d
=
unwrap_module
(
nn
.
AdaptiveAvgPool2d
)
AdaptiveAvgPool3d
=
unwrap_module
(
nn
.
AdaptiveAvgPool3d
)
TripletMarginLoss
=
unwrap_module
(
nn
.
TripletMarginLoss
)
ZeroPad2d
=
unwrap_module
(
nn
.
ZeroPad2d
)
ConstantPad1d
=
unwrap_module
(
nn
.
ConstantPad1d
)
ConstantPad2d
=
unwrap_module
(
nn
.
ConstantPad2d
)
ConstantPad3d
=
unwrap_module
(
nn
.
ConstantPad3d
)
Bilinear
=
unwrap_module
(
nn
.
Bilinear
)
CosineSimilarity
=
unwrap_module
(
nn
.
CosineSimilarity
)
Unfold
=
unwrap_module
(
nn
.
Unfold
)
Fold
=
unwrap_module
(
nn
.
Fold
)
AdaptiveLogSoftmaxWithLoss
=
unwrap_module
(
nn
.
AdaptiveLogSoftmaxWithLoss
)
TransformerEncoder
=
unwrap_module
(
nn
.
TransformerEncoder
)
TransformerDecoder
=
unwrap_module
(
nn
.
TransformerDecoder
)
TransformerEncoderLayer
=
unwrap_module
(
nn
.
TransformerEncoderLayer
)
TransformerDecoderLayer
=
unwrap_module
(
nn
.
TransformerDecoderLayer
)
Transformer
=
unwrap_module
(
nn
.
Transformer
)
Flatten
=
unwrap_module
(
nn
.
Flatten
)
Hardsigmoid
=
unwrap_module
(
nn
.
Hardsigmoid
)
for
name
in
_trace_module_names
:
if
hasattr
(
getattr
(
nn
,
name
),
'__wrapped__'
):
setattr
(
nn
,
name
,
getattr
(
nn
,
name
).
__wrapped__
)
if
version_larger_equal
(
torch
.
__version__
,
'1.6.0'
):
Hardswish
=
unwrap_module
(
nn
.
Hardswish
)
if
version_larger_equal
(
torch
.
__version__
,
'1.7.0'
):
SiLU
=
unwrap_module
(
nn
.
SiLU
)
Unflatten
=
unwrap_module
(
nn
.
Unflatten
)
TripletMarginWithDistanceLoss
=
unwrap_module
(
nn
.
TripletMarginWithDistanceLoss
)
def
inject_pytorch_nn
():
Identity
=
wrap_module
(
nn
.
Identity
)
Linear
=
wrap_module
(
nn
.
Linear
)
Conv1d
=
wrap_module
(
nn
.
Conv1d
)
Conv2d
=
wrap_module
(
nn
.
Conv2d
)
Conv3d
=
wrap_module
(
nn
.
Conv3d
)
ConvTranspose1d
=
wrap_module
(
nn
.
ConvTranspose1d
)
ConvTranspose2d
=
wrap_module
(
nn
.
ConvTranspose2d
)
ConvTranspose3d
=
wrap_module
(
nn
.
ConvTranspose3d
)
Threshold
=
wrap_module
(
nn
.
Threshold
)
ReLU
=
wrap_module
(
nn
.
ReLU
)
Hardtanh
=
wrap_module
(
nn
.
Hardtanh
)
ReLU6
=
wrap_module
(
nn
.
ReLU6
)
Sigmoid
=
wrap_module
(
nn
.
Sigmoid
)
Tanh
=
wrap_module
(
nn
.
Tanh
)
Softmax
=
wrap_module
(
nn
.
Softmax
)
Softmax2d
=
wrap_module
(
nn
.
Softmax2d
)
LogSoftmax
=
wrap_module
(
nn
.
LogSoftmax
)
ELU
=
wrap_module
(
nn
.
ELU
)
SELU
=
wrap_module
(
nn
.
SELU
)
CELU
=
wrap_module
(
nn
.
CELU
)
GLU
=
wrap_module
(
nn
.
GLU
)
GELU
=
wrap_module
(
nn
.
GELU
)
Hardshrink
=
wrap_module
(
nn
.
Hardshrink
)
LeakyReLU
=
wrap_module
(
nn
.
LeakyReLU
)
LogSigmoid
=
wrap_module
(
nn
.
LogSigmoid
)
Softplus
=
wrap_module
(
nn
.
Softplus
)
Softshrink
=
wrap_module
(
nn
.
Softshrink
)
MultiheadAttention
=
wrap_module
(
nn
.
MultiheadAttention
)
PReLU
=
wrap_module
(
nn
.
PReLU
)
Softsign
=
wrap_module
(
nn
.
Softsign
)
Softmin
=
wrap_module
(
nn
.
Softmin
)
Tanhshrink
=
wrap_module
(
nn
.
Tanhshrink
)
RReLU
=
wrap_module
(
nn
.
RReLU
)
AvgPool1d
=
wrap_module
(
nn
.
AvgPool1d
)
AvgPool2d
=
wrap_module
(
nn
.
AvgPool2d
)
AvgPool3d
=
wrap_module
(
nn
.
AvgPool3d
)
MaxPool1d
=
wrap_module
(
nn
.
MaxPool1d
)
MaxPool2d
=
wrap_module
(
nn
.
MaxPool2d
)
MaxPool3d
=
wrap_module
(
nn
.
MaxPool3d
)
MaxUnpool1d
=
wrap_module
(
nn
.
MaxUnpool1d
)
MaxUnpool2d
=
wrap_module
(
nn
.
MaxUnpool2d
)
MaxUnpool3d
=
wrap_module
(
nn
.
MaxUnpool3d
)
FractionalMaxPool2d
=
wrap_module
(
nn
.
FractionalMaxPool2d
)
FractionalMaxPool3d
=
wrap_module
(
nn
.
FractionalMaxPool3d
)
LPPool1d
=
wrap_module
(
nn
.
LPPool1d
)
LPPool2d
=
wrap_module
(
nn
.
LPPool2d
)
LocalResponseNorm
=
wrap_module
(
nn
.
LocalResponseNorm
)
BatchNorm1d
=
wrap_module
(
nn
.
BatchNorm1d
)
BatchNorm2d
=
wrap_module
(
nn
.
BatchNorm2d
)
BatchNorm3d
=
wrap_module
(
nn
.
BatchNorm3d
)
InstanceNorm1d
=
wrap_module
(
nn
.
InstanceNorm1d
)
InstanceNorm2d
=
wrap_module
(
nn
.
InstanceNorm2d
)
InstanceNorm3d
=
wrap_module
(
nn
.
InstanceNorm3d
)
LayerNorm
=
wrap_module
(
nn
.
LayerNorm
)
GroupNorm
=
wrap_module
(
nn
.
GroupNorm
)
SyncBatchNorm
=
wrap_module
(
nn
.
SyncBatchNorm
)
Dropout
=
wrap_module
(
nn
.
Dropout
)
Dropout2d
=
wrap_module
(
nn
.
Dropout2d
)
Dropout3d
=
wrap_module
(
nn
.
Dropout3d
)
AlphaDropout
=
wrap_module
(
nn
.
AlphaDropout
)
FeatureAlphaDropout
=
wrap_module
(
nn
.
FeatureAlphaDropout
)
ReflectionPad1d
=
wrap_module
(
nn
.
ReflectionPad1d
)
ReflectionPad2d
=
wrap_module
(
nn
.
ReflectionPad2d
)
ReplicationPad2d
=
wrap_module
(
nn
.
ReplicationPad2d
)
ReplicationPad1d
=
wrap_module
(
nn
.
ReplicationPad1d
)
ReplicationPad3d
=
wrap_module
(
nn
.
ReplicationPad3d
)
CrossMapLRN2d
=
wrap_module
(
nn
.
CrossMapLRN2d
)
Embedding
=
wrap_module
(
nn
.
Embedding
)
EmbeddingBag
=
wrap_module
(
nn
.
EmbeddingBag
)
RNNBase
=
wrap_module
(
nn
.
RNNBase
)
RNN
=
wrap_module
(
nn
.
RNN
)
LSTM
=
wrap_module
(
nn
.
LSTM
)
GRU
=
wrap_module
(
nn
.
GRU
)
RNNCellBase
=
wrap_module
(
nn
.
RNNCellBase
)
RNNCell
=
wrap_module
(
nn
.
RNNCell
)
LSTMCell
=
wrap_module
(
nn
.
LSTMCell
)
GRUCell
=
wrap_module
(
nn
.
GRUCell
)
PixelShuffle
=
wrap_module
(
nn
.
PixelShuffle
)
Upsample
=
wrap_module
(
nn
.
Upsample
)
UpsamplingNearest2d
=
wrap_module
(
nn
.
UpsamplingNearest2d
)
UpsamplingBilinear2d
=
wrap_module
(
nn
.
UpsamplingBilinear2d
)
PairwiseDistance
=
wrap_module
(
nn
.
PairwiseDistance
)
AdaptiveMaxPool1d
=
wrap_module
(
nn
.
AdaptiveMaxPool1d
)
AdaptiveMaxPool2d
=
wrap_module
(
nn
.
AdaptiveMaxPool2d
)
AdaptiveMaxPool3d
=
wrap_module
(
nn
.
AdaptiveMaxPool3d
)
AdaptiveAvgPool1d
=
wrap_module
(
nn
.
AdaptiveAvgPool1d
)
AdaptiveAvgPool2d
=
wrap_module
(
nn
.
AdaptiveAvgPool2d
)
AdaptiveAvgPool3d
=
wrap_module
(
nn
.
AdaptiveAvgPool3d
)
TripletMarginLoss
=
wrap_module
(
nn
.
TripletMarginLoss
)
ZeroPad2d
=
wrap_module
(
nn
.
ZeroPad2d
)
ConstantPad1d
=
wrap_module
(
nn
.
ConstantPad1d
)
ConstantPad2d
=
wrap_module
(
nn
.
ConstantPad2d
)
ConstantPad3d
=
wrap_module
(
nn
.
ConstantPad3d
)
Bilinear
=
wrap_module
(
nn
.
Bilinear
)
CosineSimilarity
=
wrap_module
(
nn
.
CosineSimilarity
)
Unfold
=
wrap_module
(
nn
.
Unfold
)
Fold
=
wrap_module
(
nn
.
Fold
)
AdaptiveLogSoftmaxWithLoss
=
wrap_module
(
nn
.
AdaptiveLogSoftmaxWithLoss
)
TransformerEncoder
=
wrap_module
(
nn
.
TransformerEncoder
)
TransformerDecoder
=
wrap_module
(
nn
.
TransformerDecoder
)
TransformerEncoderLayer
=
wrap_module
(
nn
.
TransformerEncoderLayer
)
TransformerDecoderLayer
=
wrap_module
(
nn
.
TransformerDecoderLayer
)
Transformer
=
wrap_module
(
nn
.
Transformer
)
Flatten
=
wrap_module
(
nn
.
Flatten
)
Hardsigmoid
=
wrap_module
(
nn
.
Hardsigmoid
)
if
version_larger_equal
(
torch
.
__version__
,
'1.6.0'
):
Hardswish
=
wrap_module
(
nn
.
Hardswish
)
if
version_larger_equal
(
torch
.
__version__
,
'1.7.0'
):
SiLU
=
wrap_module
(
nn
.
SiLU
)
Unflatten
=
wrap_module
(
nn
.
Unflatten
)
TripletMarginWithDistanceLoss
=
wrap_module
(
nn
.
TripletMarginWithDistanceLoss
)
for
name
in
_trace_module_names
:
setattr
(
nn
,
name
,
basic_unit
(
getattr
(
nn
,
name
)))
test/ut/retiarii/test_cgo_engine.py
View file @
443ba8c1
import
json
import
os
import
threading
import
unittest
...
...
@@ -161,7 +160,7 @@ def _new_trainer():
def
_load_mnist
(
n_models
:
int
=
1
):
path
=
Path
(
__file__
).
parent
/
'mnist_pytorch.json'
with
open
(
path
)
as
f
:
mnist_model
=
Model
.
_load
(
json
.
load
(
f
))
mnist_model
=
Model
.
_load
(
nni
.
load
(
fp
=
f
))
mnist_model
.
evaluator
=
_new_trainer
()
if
n_models
==
1
:
...
...
@@ -176,12 +175,12 @@ def _load_mnist(n_models: int = 1):
def
_get_final_result
():
result
=
json
.
load
s
(
nni
.
runtime
.
platform
.
test
.
_last_metric
)[
'value'
]
result
=
nni
.
load
(
nni
.
runtime
.
platform
.
test
.
_last_metric
)[
'value'
]
if
isinstance
(
result
,
list
):
return
[
float
(
_
)
for
_
in
result
]
else
:
if
isinstance
(
result
,
str
)
and
'['
in
result
:
return
json
.
load
s
(
result
)
return
nni
.
load
(
result
)
return
[
float
(
result
)]
...
...
@@ -311,7 +310,7 @@ class CGOEngineTest(unittest.TestCase):
if
torch
.
cuda
.
is_available
()
and
torch
.
cuda
.
device_count
()
>=
2
:
cmd
,
data
=
protocol
.
receive
()
params
=
json
.
load
s
(
data
)
params
=
nni
.
load
(
data
)
tt
.
init_params
(
params
)
...
...
test/ut/retiarii/test_lightning_trainer.py
View file @
443ba8c1
...
...
@@ -50,7 +50,7 @@ class FCNet(nn.Module):
return
output
.
view
(
-
1
)
@
serialize_cls
@
nni
.
trace
class
DiabetesDataset
(
Dataset
):
def
__init__
(
self
,
train
=
True
):
data
=
load_diabetes
()
...
...
test/ut/retiarii/test_namespace.py
0 → 100644
View file @
443ba8c1
import
torch
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
model_wrapper
@
model_wrapper
class
Model
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
in_channels
,
10
,
3
)
self
.
conv2
=
nn
.
LayerChoice
([
nn
.
Conv2d
(
10
,
10
,
3
),
nn
.
MaxPool2d
(
3
)
])
self
.
conv3
=
nn
.
LayerChoice
([
nn
.
Identity
(),
nn
.
Conv2d
(
10
,
10
,
1
)
])
self
.
avgpool
=
nn
.
AdaptiveAvgPool2d
((
1
,
1
))
self
.
fc
=
nn
.
Linear
(
10
,
1
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
conv3
(
x
)
x
=
self
.
avgpool
(
x
).
view
(
x
.
size
(
0
),
-
1
)
x
=
self
.
fc
(
x
)
return
x
@
model_wrapper
class
ModelInner
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
net1
=
nn
.
LayerChoice
([
nn
.
Linear
(
10
,
10
),
nn
.
Linear
(
10
,
10
,
bias
=
False
)
])
self
.
net2
=
nn
.
LayerChoice
([
nn
.
Linear
(
10
,
10
),
nn
.
Linear
(
10
,
10
,
bias
=
False
)
])
def
forward
(
self
,
x
):
x
=
self
.
net1
(
x
)
x
=
self
.
net2
(
x
)
return
x
@
model_wrapper
class
ModelNested
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
fc1
=
ModelInner
()
self
.
fc2
=
nn
.
LayerChoice
([
nn
.
Linear
(
10
,
10
),
nn
.
Linear
(
10
,
10
,
bias
=
False
)
])
self
.
fc3
=
ModelInner
()
def
forward
(
self
,
x
):
return
self
.
fc3
(
self
.
fc2
(
self
.
fc1
(
x
)))
def
test_model_wrapper
():
model
=
Model
(
3
)
assert
model
.
trace_symbol
==
Model
.
__wrapped__
assert
model
.
trace_kwargs
==
{
'in_channels'
:
3
}
assert
model
.
conv2
.
label
==
'model_1'
assert
model
.
conv3
.
label
==
'model_2'
assert
model
(
torch
.
randn
(
1
,
3
,
5
,
5
)).
size
()
==
torch
.
Size
([
1
,
1
])
model
=
Model
(
4
)
assert
model
.
trace_symbol
==
Model
.
__wrapped__
assert
model
.
conv2
.
label
==
'model_1'
# not changed
def
test_model_wrapper_nested
():
model
=
ModelNested
()
assert
model
.
fc1
.
net1
.
label
==
'model_1_1'
assert
model
.
fc1
.
net2
.
label
==
'model_1_2'
assert
model
.
fc2
.
label
==
'model_2'
assert
model
.
fc3
.
net1
.
label
==
'model_3_1'
assert
model
.
fc3
.
net2
.
label
==
'model_3_2'
if
__name__
==
'__main__'
:
test_model_wrapper_nested
()
test/ut/retiarii/test_serializer.py
deleted
100644 → 0
View file @
896c516f
import
json
import
math
from
pathlib
import
Path
import
re
import
sys
import
torch
from
nni.retiarii
import
json_dumps
,
json_loads
,
serialize
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
sys
.
path
.
insert
(
0
,
Path
(
__file__
).
parent
.
as_posix
())
from
imported.model
import
ImportTest
class
Foo
:
def
__init__
(
self
,
a
,
b
=
1
):
self
.
aa
=
a
self
.
bb
=
[
b
+
1
for
_
in
range
(
1000
)]
def
__eq__
(
self
,
other
):
return
self
.
aa
==
other
.
aa
and
self
.
bb
==
other
.
bb
def
test_serialize
():
module
=
serialize
(
Foo
,
3
)
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
serialize
(
Foo
,
b
=
2
,
a
=
1
)
assert
json_loads
(
json_dumps
(
module
))
==
module
module
=
serialize
(
Foo
,
Foo
(
1
),
5
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
>
200
# should not be too longer if the serialization is correct
module
=
serialize
(
Foo
,
serialize
(
Foo
,
1
),
5
)
dumped_module
=
json_dumps
(
module
)
assert
len
(
dumped_module
)
<
200
# should not be too longer if the serialization is correct
assert
json_loads
(
dumped_module
)
==
module
def
test_basic_unit
():
module
=
ImportTest
(
3
,
0.5
)
assert
json_loads
(
json_dumps
(
module
))
==
module
def
test_dataset
():
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
)
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
dumped_ans
=
{
"__type__"
:
"torch.utils.data.dataloader.DataLoader"
,
"arguments"
:
{
"batch_size"
:
10
,
"dataset"
:
{
"__type__"
:
"torchvision.datasets.mnist.MNIST"
,
"arguments"
:
{
"root"
:
"data/mnist"
,
"train"
:
False
,
"download"
:
True
}
}
}
}
assert
json_dumps
(
dataloader
)
==
json_dumps
(
dumped_ans
)
dataloader
=
json_loads
(
json_dumps
(
dumped_ans
))
assert
isinstance
(
dataloader
,
DataLoader
)
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
serialize
(
transforms
.
Compose
,
[
serialize
(
transforms
.
ToTensor
),
serialize
(
transforms
.
Normalize
,
(
0.1307
,),
(
0.3081
,))]
))
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
dataset
=
serialize
(
MNIST
,
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))]))
dataloader
=
serialize
(
DataLoader
,
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
json_loads
(
json_dumps
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
def
test_type
():
assert
json_dumps
(
torch
.
optim
.
Adam
)
==
'{"__typename__": "torch.optim.adam.Adam"}'
assert
json_loads
(
'{"__typename__": "torch.optim.adam.Adam"}'
)
==
torch
.
optim
.
Adam
assert
re
.
match
(
r
'{"__typename__": "(.*)test_serializer.Foo"}'
,
json_dumps
(
Foo
))
assert
json_dumps
(
math
.
floor
)
==
'{"__typename__": "math.floor"}'
assert
json_loads
(
'{"__typename__": "math.floor"}'
)
==
math
.
floor
if
__name__
==
'__main__'
:
test_serialize
()
test_basic_unit
()
test_dataset
()
test_type
()
test/ut/
retiarii
/imported/model.py
→
test/ut/
sdk
/imported/model.py
View file @
443ba8c1
File moved
test/ut/sdk/test_serializer.py
View file @
443ba8c1
import
math
from
pathlib
import
Path
import
re
import
sys
import
nni
import
torch
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
from
nni.common.serializer
import
is_traceable
if
True
:
# prevent auto formatting
sys
.
path
.
insert
(
0
,
Path
(
__file__
).
parent
.
as_posix
())
from
imported.model
import
ImportTest
@
nni
.
trace
...
...
@@ -23,8 +37,8 @@ def test_simple_class():
assert
'"__kwargs__": {"a": 1, "b": 2}'
in
dump_str
assert
'"__symbol__"'
in
dump_str
instance
=
nni
.
load
(
dump_str
)
assert
instance
.
get
().
_a
==
1
assert
instance
.
get
().
_b
==
2
assert
instance
.
_a
==
1
assert
instance
.
_b
==
2
def
test_external_class
():
...
...
@@ -44,7 +58,7 @@ def test_external_class():
r
'"__kwargs__": {"in_channels": 3, "out_channels": 16, "kernel_size": 3}}'
conv
=
nni
.
load
(
nni
.
dump
(
conv
))
assert
conv
.
get
().
kernel_size
==
(
3
,
3
)
assert
conv
.
kernel_size
==
(
3
,
3
)
def
test_nested_class
():
...
...
@@ -53,8 +67,8 @@ def test_nested_class():
assert
b
.
_a
.
_a
==
1
dump_str
=
nni
.
dump
(
b
)
b
=
nni
.
load
(
dump_str
)
assert
repr
(
b
)
==
'SerializableObject(type=SimpleClass, a=SerializableObject(type=SimpleClass, a=1, b=2))'
assert
b
.
get
().
_a
.
_a
==
1
assert
'SimpleClass object at'
in
repr
(
b
)
assert
b
.
_a
.
_a
==
1
def
test_unserializable
():
...
...
@@ -64,8 +78,137 @@ def test_unserializable():
assert
a
.
_a
==
1
def
test_function
():
t
=
nni
.
trace
(
math
.
sqrt
,
kw_only
=
False
)(
3
)
assert
1
<
t
<
2
assert
t
.
trace_symbol
==
math
.
sqrt
assert
t
.
trace_args
==
[
3
]
t
=
nni
.
load
(
nni
.
dump
(
t
))
assert
1
<
t
<
2
assert
not
is_traceable
(
t
)
# trace not recovered, expected, limitation
def
simple_class_factory
(
bb
=
3.
):
return
SimpleClass
(
1
,
bb
)
t
=
nni
.
trace
(
simple_class_factory
)(
4
)
ts
=
nni
.
dump
(
t
)
assert
'__kwargs__'
in
ts
t
=
nni
.
load
(
ts
)
assert
t
.
_a
==
1
assert
is_traceable
(
t
)
t
=
t
.
trace_copy
()
assert
is_traceable
(
t
)
assert
t
.
trace_symbol
(
10
).
_b
==
10
assert
t
.
trace_kwargs
[
'bb'
]
==
4
assert
is_traceable
(
t
.
trace_copy
())
class
Foo
:
def
__init__
(
self
,
a
,
b
=
1
):
self
.
aa
=
a
self
.
bb
=
[
b
+
1
for
_
in
range
(
1000
)]
def
__eq__
(
self
,
other
):
return
self
.
aa
==
other
.
aa
and
self
.
bb
==
other
.
bb
def
test_custom_class
():
module
=
nni
.
trace
(
Foo
)(
3
)
assert
nni
.
load
(
nni
.
dump
(
module
))
==
module
module
=
nni
.
trace
(
Foo
)(
b
=
2
,
a
=
1
)
assert
nni
.
load
(
nni
.
dump
(
module
))
==
module
module
=
nni
.
trace
(
Foo
)(
Foo
(
1
),
5
)
dumped_module
=
nni
.
dump
(
module
)
assert
len
(
dumped_module
)
>
200
# should not be too longer if the serialization is correct
module
=
nni
.
trace
(
Foo
)(
nni
.
trace
(
Foo
)(
1
),
5
)
dumped_module
=
nni
.
dump
(
module
)
assert
nni
.
load
(
dumped_module
)
==
module
class
Foo
:
def
__init__
(
self
,
a
,
b
=
1
):
self
.
aa
=
a
self
.
bb
=
[
b
+
1
for
_
in
range
(
1000
)]
def
__eq__
(
self
,
other
):
return
self
.
aa
==
other
.
aa
and
self
.
bb
==
other
.
bb
def
test_basic_unit_and_custom_import
():
module
=
ImportTest
(
3
,
0.5
)
ss
=
nni
.
dump
(
module
)
assert
ss
==
r
'{"__symbol__": "path:imported.model.ImportTest", "__kwargs__": {"foo": 3, "bar": 0.5}}'
assert
nni
.
load
(
nni
.
dump
(
module
))
==
module
import
nni.retiarii.nn.pytorch
as
nn
module
=
nn
.
Conv2d
(
3
,
10
,
3
,
bias
=
False
)
ss
=
nni
.
dump
(
module
)
assert
ss
==
r
'{"__symbol__": "path:torch.nn.modules.conv.Conv2d", "__kwargs__": {"in_channels": 3, "out_channels": 10, "kernel_size": 3, "bias": false}}'
assert
nni
.
load
(
ss
).
bias
is
None
def
test_dataset
():
dataset
=
nni
.
trace
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
)
dataloader
=
nni
.
trace
(
DataLoader
)(
dataset
,
batch_size
=
10
)
dumped_ans
=
{
"__symbol__"
:
"path:torch.utils.data.dataloader.DataLoader"
,
"__kwargs__"
:
{
"dataset"
:
{
"__symbol__"
:
"path:torchvision.datasets.mnist.MNIST"
,
"__kwargs__"
:
{
"root"
:
"data/mnist"
,
"train"
:
False
,
"download"
:
True
}
},
"batch_size"
:
10
}
}
print
(
nni
.
dump
(
dataloader
))
print
(
nni
.
dump
(
dumped_ans
))
assert
nni
.
dump
(
dataloader
)
==
nni
.
dump
(
dumped_ans
)
dataloader
=
nni
.
load
(
nni
.
dump
(
dumped_ans
))
assert
isinstance
(
dataloader
,
DataLoader
)
dataset
=
nni
.
trace
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
nni
.
trace
(
transforms
.
Compose
)([
nni
.
trace
(
transforms
.
ToTensor
)(),
nni
.
trace
(
transforms
.
Normalize
)((
0.1307
,),
(
0.3081
,))
]))
dataloader
=
nni
.
trace
(
DataLoader
)(
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
nni
.
load
(
nni
.
dump
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
dataset
=
nni
.
trace
(
MNIST
)(
root
=
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
nni
.
trace
(
transforms
.
Compose
)(
[
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))]
))
dataloader
=
nni
.
trace
(
DataLoader
)(
dataset
,
batch_size
=
10
)
x
,
y
=
next
(
iter
(
nni
.
load
(
nni
.
dump
(
dataloader
))))
assert
x
.
size
()
==
torch
.
Size
([
10
,
1
,
28
,
28
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
def
test_type
():
assert
nni
.
dump
(
torch
.
optim
.
Adam
)
==
'{"__nni_type__": "path:torch.optim.adam.Adam"}'
assert
nni
.
load
(
'{"__nni_type__": "path:torch.optim.adam.Adam"}'
)
==
torch
.
optim
.
Adam
assert
Foo
==
nni
.
load
(
nni
.
dump
(
Foo
))
assert
nni
.
dump
(
math
.
floor
)
==
'{"__nni_type__": "path:math.floor"}'
assert
nni
.
load
(
'{"__nni_type__": "path:math.floor"}'
)
==
math
.
floor
def
test_lightning_earlystop
():
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
from
pytorch_lightning.callbacks.early_stopping
import
EarlyStopping
trainer
=
pl
.
Trainer
(
callbacks
=
[
nni
.
trace
(
EarlyStopping
)(
monitor
=
"val_loss"
)])
trainer
=
nni
.
load
(
nni
.
dump
(
trainer
))
assert
any
(
isinstance
(
callback
,
EarlyStopping
)
for
callback
in
trainer
.
callbacks
)
if
__name__
==
'__main__'
:
test_simple_class
()
test_external_class
()
test_nested_class
()
test_unserializable
()
# test_simple_class()
# test_external_class()
# test_nested_class()
# test_unserializable()
# test_basic_unit()
test_type
()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment