Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
21abc280
"vscode:/vscode.git/clone" did not exist on "562e1e27673bea4d9ce6793d418c7788138e49ed"
Unverified
Commit
21abc280
authored
Mar 03, 2022
by
Yuge Zhang
Committed by
GitHub
Mar 03, 2022
Browse files
Fix #4434: support pickle in serializer (#4552)
parent
c447249c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
307 additions
and
62 deletions
+307
-62
nni/algorithms/compression/v2/pytorch/utils/constructor_helper.py
...rithms/compression/v2/pytorch/utils/constructor_helper.py
+6
-6
nni/common/serializer.py
nni/common/serializer.py
+174
-29
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
+1
-0
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+42
-17
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+1
-0
test/ut/sdk/test_serializer.py
test/ut/sdk/test_serializer.py
+83
-10
No files found.
nni/algorithms/compression/v2/pytorch/utils/constructor_helper.py
View file @
21abc280
...
@@ -10,7 +10,7 @@ from torch.optim import Optimizer
...
@@ -10,7 +10,7 @@ from torch.optim import Optimizer
from
torch.optim.lr_scheduler
import
_LRScheduler
from
torch.optim.lr_scheduler
import
_LRScheduler
from
nni.common.serializer
import
_trace_cls
from
nni.common.serializer
import
_trace_cls
from
nni.common.serializer
import
Traceable
from
nni.common.serializer
import
Traceable
,
is_traceable
__all__
=
[
'OptimizerConstructHelper'
,
'LRSchedulerConstructHelper'
]
__all__
=
[
'OptimizerConstructHelper'
,
'LRSchedulerConstructHelper'
]
...
@@ -80,14 +80,14 @@ class OptimizerConstructHelper(ConstructHelper):
...
@@ -80,14 +80,14 @@ class OptimizerConstructHelper(ConstructHelper):
@
staticmethod
@
staticmethod
def
from_trace
(
model
:
Module
,
optimizer_trace
:
Traceable
):
def
from_trace
(
model
:
Module
,
optimizer_trace
:
Traceable
):
assert
is
instanc
e
(
optimizer_trace
,
Traceable
),
\
assert
is
_traceabl
e
(
optimizer_trace
),
\
'Please use nni.trace to wrap the optimizer class before initialize the optimizer.'
'Please use nni.trace to wrap the optimizer class before initialize the optimizer.'
assert
isinstance
(
optimizer_trace
,
Optimizer
),
\
assert
isinstance
(
optimizer_trace
,
Optimizer
),
\
'It is not an instance of torch.nn.Optimizer.'
'It is not an instance of torch.nn.Optimizer.'
return
OptimizerConstructHelper
(
model
,
return
OptimizerConstructHelper
(
model
,
optimizer_trace
.
_get_nni_attr
(
'
symbol
'
)
,
optimizer_trace
.
trace_
symbol
,
*
optimizer_trace
.
_get_nni_attr
(
'
args
'
)
,
*
optimizer_trace
.
trace_
args
,
**
optimizer_trace
.
_get_nni_attr
(
'
kwargs
'
)
)
**
optimizer_trace
.
trace_
kwargs
)
class
LRSchedulerConstructHelper
(
ConstructHelper
):
class
LRSchedulerConstructHelper
(
ConstructHelper
):
...
@@ -112,7 +112,7 @@ class LRSchedulerConstructHelper(ConstructHelper):
...
@@ -112,7 +112,7 @@ class LRSchedulerConstructHelper(ConstructHelper):
@
staticmethod
@
staticmethod
def
from_trace
(
lr_scheduler_trace
:
Traceable
):
def
from_trace
(
lr_scheduler_trace
:
Traceable
):
assert
is
instanc
e
(
lr_scheduler_trace
,
Traceable
),
\
assert
is
_traceabl
e
(
lr_scheduler_trace
),
\
'Please use nni.trace to wrap the lr scheduler class before initialize the scheduler.'
'Please use nni.trace to wrap the lr scheduler class before initialize the scheduler.'
assert
isinstance
(
lr_scheduler_trace
,
_LRScheduler
),
\
assert
isinstance
(
lr_scheduler_trace
,
_LRScheduler
),
\
'It is not an instance of torch.nn.lr_scheduler._LRScheduler.'
'It is not an instance of torch.nn.lr_scheduler._LRScheduler.'
...
...
nni/common/serializer.py
View file @
21abc280
...
@@ -5,6 +5,7 @@ import copy
...
@@ -5,6 +5,7 @@ import copy
import
functools
import
functools
import
inspect
import
inspect
import
numbers
import
numbers
import
sys
import
types
import
types
import
warnings
import
warnings
from
io
import
IOBase
from
io
import
IOBase
...
@@ -13,7 +14,7 @@ from typing import Any, Dict, List, Optional, TypeVar, Union
...
@@ -13,7 +14,7 @@ from typing import Any, Dict, List, Optional, TypeVar, Union
import
cloudpickle
# use cloudpickle as backend for unserializable types and instances
import
cloudpickle
# use cloudpickle as backend for unserializable types and instances
import
json_tricks
# use json_tricks as serializer backend
import
json_tricks
# use json_tricks as serializer backend
__all__
=
[
'trace'
,
'dump'
,
'load'
,
'PayloadTooLarge'
,
'Translatable'
,
'Traceable'
,
'is_traceable'
]
__all__
=
[
'trace'
,
'dump'
,
'load'
,
'PayloadTooLarge'
,
'Translatable'
,
'Traceable'
,
'is_traceable'
,
'is_wrapped_with_trace'
]
T
=
TypeVar
(
'T'
)
T
=
TypeVar
(
'T'
)
...
@@ -23,46 +24,43 @@ class PayloadTooLarge(Exception):
...
@@ -23,46 +24,43 @@ class PayloadTooLarge(Exception):
pass
pass
class
Traceable
(
abc
.
ABC
)
:
class
Traceable
:
"""
"""
A traceable object have copy and dict. Copy and mutate are used to copy the object for further mutations.
A traceable object have copy and dict. Copy and mutate are used to copy the object for further mutations.
Dict returns a TraceDictType to enable serialization.
Dict returns a TraceDictType to enable serialization.
"""
"""
@
abc
.
abstractmethod
def
trace_copy
(
self
)
->
'Traceable'
:
def
trace_copy
(
self
)
->
'Traceable'
:
"""
"""
Perform a shallow copy.
Perform a shallow copy.
NOTE: NONE of the attributes will be preserved.
NOTE: NONE of the attributes will be preserved.
This is the one that should be used when you want to "mutate" a serializable object.
This is the one that should be used when you want to "mutate" a serializable object.
"""
"""
...
raise
NotImplementedError
()
@
property
@
property
@
abc
.
abstractmethod
def
trace_symbol
(
self
)
->
Any
:
def
trace_symbol
(
self
)
->
Any
:
"""
"""
Symbol object. Could be a class or a function.
Symbol object. Could be a class or a function.
``get_hybrid_cls_or_func_name`` and ``import_cls_or_func_from_hybrid_name`` is a pair to
``get_hybrid_cls_or_func_name`` and ``import_cls_or_func_from_hybrid_name`` is a pair to
convert the symbol into a string and convert the string back to symbol.
convert the symbol into a string and convert the string back to symbol.
"""
"""
...
raise
NotImplementedError
()
@
property
@
property
@
abc
.
abstractmethod
def
trace_args
(
self
)
->
List
[
Any
]:
def
trace_args
(
self
)
->
List
[
Any
]:
"""
"""
List of positional arguments passed to symbol. Usually empty if ``kw_only`` is true,
List of positional arguments passed to symbol. Usually empty if ``kw_only`` is true,
in which case all the positional arguments are converted into keyword arguments.
in which case all the positional arguments are converted into keyword arguments.
"""
"""
...
raise
NotImplementedError
()
@
property
@
property
@
abc
.
abstractmethod
def
trace_kwargs
(
self
)
->
Dict
[
str
,
Any
]:
def
trace_kwargs
(
self
)
->
Dict
[
str
,
Any
]:
"""
"""
Dict of keyword arguments.
Dict of keyword arguments.
"""
"""
...
raise
NotImplementedError
()
class
Translatable
(
abc
.
ABC
):
class
Translatable
(
abc
.
ABC
):
...
@@ -84,13 +82,27 @@ class Translatable(abc.ABC):
...
@@ -84,13 +82,27 @@ class Translatable(abc.ABC):
def
is_traceable
(
obj
:
Any
)
->
bool
:
def
is_traceable
(
obj
:
Any
)
->
bool
:
"""
"""
Check whether an object is a traceable instance (not type).
Check whether an object is a traceable instance or type.
Note that an object is traceable only means that it implements the "Traceable" interface,
and the properties have been implemented. It doesn't necessary mean that its type is wrapped with trace,
because the properties could be added **after** the instance has been created.
"""
"""
return
hasattr
(
obj
,
'trace_copy'
)
and
\
return
hasattr
(
obj
,
'trace_copy'
)
and
\
hasattr
(
obj
,
'trace_symbol'
)
and
\
hasattr
(
obj
,
'trace_symbol'
)
and
\
hasattr
(
obj
,
'trace_args'
)
and
\
hasattr
(
obj
,
'trace_args'
)
and
\
hasattr
(
obj
,
'trace_kwargs'
)
and
\
hasattr
(
obj
,
'trace_kwargs'
)
not
inspect
.
isclass
(
obj
)
def
is_wrapped_with_trace
(
cls_or_func
:
Any
)
->
bool
:
"""
Check whether a function or class is already wrapped with ``@nni.trace``.
If a class or function is already wrapped with trace, then the created object must be "traceable".
"""
return
getattr
(
cls_or_func
,
'_traced'
,
False
)
and
(
not
hasattr
(
cls_or_func
,
'__dict__'
)
or
# in case it's a function
'_traced'
in
cls_or_func
.
__dict__
# must be in this class, super-class traced doesn't count
)
class
SerializableObject
(
Traceable
):
class
SerializableObject
(
Traceable
):
...
@@ -160,6 +172,15 @@ class SerializableObject(Traceable):
...
@@ -160,6 +172,15 @@ class SerializableObject(Traceable):
def
inject_trace_info
(
obj
:
Any
,
symbol
:
T
,
args
:
List
[
Any
],
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
def
inject_trace_info
(
obj
:
Any
,
symbol
:
T
,
args
:
List
[
Any
],
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
# If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
# If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
# Make obj complying with the interface of traceable, though we cannot change its base class.
obj
.
__dict__
.
update
(
_nni_symbol
=
symbol
,
_nni_args
=
args
,
_nni_kwargs
=
kwargs
)
return
obj
def
_make_class_traceable
(
cls
:
T
,
create_wrapper
:
bool
=
False
)
->
T
:
# Make an already exist class traceable, without creating a new class.
# Should be used together with `inject_trace_info`.
def
getter_factory
(
x
):
def
getter_factory
(
x
):
return
lambda
self
:
self
.
__dict__
[
'_nni_'
+
x
]
return
lambda
self
:
self
.
__dict__
[
'_nni_'
+
x
]
...
@@ -184,20 +205,18 @@ def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, An
...
@@ -184,20 +205,18 @@ def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, An
'trace_copy'
:
trace_copy
'trace_copy'
:
trace_copy
}
}
if
hasattr
(
obj
,
'__class__'
)
and
hasattr
(
obj
,
'__dict__'
)
:
if
not
create_wrapper
:
for
name
,
method
in
attributes
.
items
():
for
name
,
method
in
attributes
.
items
():
setattr
(
obj
.
__class__
,
name
,
method
)
setattr
(
cls
,
name
,
method
)
return
cls
else
:
else
:
wrapper
=
type
(
'wrapper'
,
(
Traceable
,
type
(
obj
)),
attributes
)
# sometimes create_wrapper is mandatory, e.g., for built-in types like list/int.
obj
=
wrapper
(
obj
)
# pylint: disable=abstract-class-instantiated
# but I don't want to check here because it's unreliable.
wrapper
=
type
(
'wrapper'
,
(
Traceable
,
cls
),
attributes
)
# make obj complying with the interface of traceable, though we cannot change its base class
return
wrapper
obj
.
__dict__
.
update
(
_nni_symbol
=
symbol
,
_nni_args
=
args
,
_nni_kwargs
=
kwargs
)
return
obj
def
trace
(
cls_or_func
:
T
=
None
,
*
,
kw_only
:
bool
=
True
)
->
Union
[
T
,
Traceable
]:
def
trace
(
cls_or_func
:
T
=
None
,
*
,
kw_only
:
bool
=
True
,
inheritable
:
bool
=
False
)
->
Union
[
T
,
Traceable
]:
"""
"""
Annotate a function or a class if you want to preserve where it comes from.
Annotate a function or a class if you want to preserve where it comes from.
This is usually used in the following scenarios:
This is usually used in the following scenarios:
...
@@ -221,6 +240,9 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
...
@@ -221,6 +240,9 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
Therefore, in some cases, some positional arguments will still be kept.
Therefore, in some cases, some positional arguments will still be kept.
If ``inheritable`` is true, the trace information from superclass will also be available in subclass.
This however, will make the subclass un-trace-able. Note that this argument has no effect when tracing functions.
.. warning::
.. warning::
Generators will be first expanded into a list, and the resulting list will be further passed into the wrapped function/class.
Generators will be first expanded into a list, and the resulting list will be further passed into the wrapped function/class.
...
@@ -237,10 +259,10 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
...
@@ -237,10 +259,10 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
def
wrap
(
cls_or_func
):
def
wrap
(
cls_or_func
):
# already annotated, do nothing
# already annotated, do nothing
if
getattr
(
cls_or_func
,
'_traced'
,
False
):
if
is_wrapped_with_trace
(
cls_or_func
):
return
cls_or_func
return
cls_or_func
if
isinstance
(
cls_or_func
,
type
):
if
isinstance
(
cls_or_func
,
type
):
cls_or_func
=
_trace_cls
(
cls_or_func
,
kw_only
)
cls_or_func
=
_trace_cls
(
cls_or_func
,
kw_only
,
inheritable
=
inheritable
)
elif
_is_function
(
cls_or_func
):
elif
_is_function
(
cls_or_func
):
cls_or_func
=
_trace_func
(
cls_or_func
,
kw_only
)
cls_or_func
=
_trace_func
(
cls_or_func
,
kw_only
)
else
:
else
:
...
@@ -353,11 +375,60 @@ def load(string: Optional[str] = None, *, fp: Optional[Any] = None, ignore_comme
...
@@ -353,11 +375,60 @@ def load(string: Optional[str] = None, *, fp: Optional[Any] = None, ignore_comme
return
json_tricks
.
load
(
fp
,
obj_pairs_hooks
=
hooks
,
**
json_tricks_kwargs
)
return
json_tricks
.
load
(
fp
,
obj_pairs_hooks
=
hooks
,
**
json_tricks_kwargs
)
def
_trace_cls
(
base
,
kw_only
,
call_super
=
True
):
def
_trace_cls
(
base
,
kw_only
,
call_super
=
True
,
inheritable
=
False
):
# the implementation to trace a class is to store a copy of init arguments
# the implementation to trace a class is to store a copy of init arguments
# this won't support class that defines a customized new but should work for most cases
# this won't support class that defines a customized new but should work for most cases
class
wrapper
(
SerializableObject
,
base
):
if
sys
.
platform
!=
'linux'
:
if
not
call_super
:
raise
ValueError
(
"'call_super' is mandatory to be set true on non-linux platform"
)
try
:
# In non-linux envs, dynamically creating new classes doesn't work with pickle.
# We have to replace the ``__init__`` with a new ``__init__``.
# This, however, causes side-effects where the replacement is not intended.
# This also doesn't work built-in types (e.g., OrderedDict), and the replacement
# won't be effective any more if ``nni.trace`` is called in-place (e.g., ``nni.trace(nn.Conv2d)(...)``).
original_init
=
base
.
__init__
# Makes the new init have the exact same signature as the old one,
# so as to make pytorch-lightning happy.
# https://github.com/PyTorchLightning/pytorch-lightning/blob/4cc05b2cf98e49168a5f5dc265647d75d1d3aae9/pytorch_lightning/utilities/parsing.py#L143
@
functools
.
wraps
(
original_init
)
def
new_init
(
self
,
*
args
,
**
kwargs
):
args
,
kwargs
=
_formulate_arguments
(
original_init
,
args
,
kwargs
,
kw_only
,
is_class_init
=
True
)
original_init
(
self
,
*
[
_argument_processor
(
arg
)
for
arg
in
args
],
**
{
kw
:
_argument_processor
(
arg
)
for
kw
,
arg
in
kwargs
.
items
()}
)
inject_trace_info
(
self
,
base
,
args
,
kwargs
)
base
.
__init__
=
new_init
base
=
_make_class_traceable
(
base
)
return
base
except
TypeError
:
warnings
.
warn
(
"In-place __init__ replacement failed in `@nni.trace`, probably because the type is a built-in/extension type, "
"and it's __init__ can't be replaced. `@nni.trace` is now falling back to the 'inheritance' approach. "
"However, this could cause issues when using pickle. See https://github.com/microsoft/nni/issues/4434"
,
RuntimeWarning
)
# This is trying to solve the case where superclass and subclass are both decorated with @nni.trace.
# We use a metaclass to "unwrap" the superclass.
# However, this doesn't work if:
# 1. Base class already has a customized metaclass. We will raise error in that class.
# 2. SerializableObject in ancester (instead of parent). I think this case is rare and I didn't handle this case yet. FIXME
if
type
(
base
)
is
type
and
not
inheritable
:
metaclass
=
_unwrap_metaclass
else
:
metaclass
=
type
if
SerializableObject
in
inspect
.
getmro
(
base
):
raise
TypeError
(
f
"
{
base
}
has a superclass already decorated with trace, and it's using a customized metaclass
{
type
(
base
)
}
. "
"Please either use the default metaclass, or remove trace from the super-class."
)
class
wrapper
(
SerializableObject
,
base
,
metaclass
=
metaclass
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
# store a copy of initial parameters
# store a copy of initial parameters
args
,
kwargs
=
_formulate_arguments
(
base
.
__init__
,
args
,
kwargs
,
kw_only
,
is_class_init
=
True
)
args
,
kwargs
=
_formulate_arguments
(
base
.
__init__
,
args
,
kwargs
,
kw_only
,
is_class_init
=
True
)
...
@@ -365,6 +436,32 @@ def _trace_cls(base, kw_only, call_super=True):
...
@@ -365,6 +436,32 @@ def _trace_cls(base, kw_only, call_super=True):
# calling serializable object init to initialize the full object
# calling serializable object init to initialize the full object
super
().
__init__
(
symbol
=
base
,
args
=
args
,
kwargs
=
kwargs
,
call_super
=
call_super
)
super
().
__init__
(
symbol
=
base
,
args
=
args
,
kwargs
=
kwargs
,
call_super
=
call_super
)
def
__reduce__
(
self
):
# The issue that decorator and pickler doesn't play well together is well known.
# The workaround solution is to use a fool class (_pickling_object) which pretends to be the pickled object.
# We then put the original type, as well as args and kwargs in its `__new__` argument.
# I suspect that their could still be problems when things get complex,
# e.g., the wrapped class has a custom pickling (`__reduce__``) or `__new__`.
# But it can't be worse because the previous pickle doesn't work at all.
#
# Linked issue: https://github.com/microsoft/nni/issues/4434
# SO: https://stackoverflow.com/questions/52185507/pickle-and-decorated-classes-picklingerror-not-the-same-object
# Store the inner class. The wrapped class couldn't be properly pickled.
type_
=
cloudpickle
.
dumps
(
type
(
self
).
__wrapped__
)
# in case they have customized ``__getstate__``.
if
hasattr
(
self
,
'__getstate__'
):
obj_
=
self
.
__getstate__
()
else
:
obj_
=
self
.
__dict__
# Pickle can't handle type objects.
if
'_nni_symbol'
in
obj_
:
obj_
[
'_nni_symbol'
]
=
cloudpickle
.
dumps
(
obj_
[
'_nni_symbol'
])
return
_pickling_object
,
(
type_
,
kw_only
,
obj_
)
_copy_class_wrapper_attributes
(
base
,
wrapper
)
_copy_class_wrapper_attributes
(
base
,
wrapper
)
return
wrapper
return
wrapper
...
@@ -391,6 +488,8 @@ def _trace_func(func, kw_only):
...
@@ -391,6 +488,8 @@ def _trace_func(func, kw_only):
elif
hasattr
(
res
,
'__class__'
)
and
hasattr
(
res
,
'__dict__'
):
elif
hasattr
(
res
,
'__class__'
)
and
hasattr
(
res
,
'__dict__'
):
# is a class, inject interface directly
# is a class, inject interface directly
# need to be done before primitive types because there could be inheritance here.
# need to be done before primitive types because there could be inheritance here.
if
not
getattr
(
type
(
res
),
'_traced'
,
False
):
_make_class_traceable
(
type
(
res
),
False
)
# in-place
res
=
inject_trace_info
(
res
,
func
,
args
,
kwargs
)
res
=
inject_trace_info
(
res
,
func
,
args
,
kwargs
)
elif
isinstance
(
res
,
(
collections
.
abc
.
Callable
,
types
.
ModuleType
,
IOBase
)):
elif
isinstance
(
res
,
(
collections
.
abc
.
Callable
,
types
.
ModuleType
,
IOBase
)):
raise
TypeError
(
f
'Try to add trace info to
{
res
}
, but functions and modules are not supported.'
)
raise
TypeError
(
f
'Try to add trace info to
{
res
}
, but functions and modules are not supported.'
)
...
@@ -400,6 +499,8 @@ def _trace_func(func, kw_only):
...
@@ -400,6 +499,8 @@ def _trace_func(func, kw_only):
# will be directly captured by python json encoder
# will be directly captured by python json encoder
# and thus not possible to restore the trace parameters after dump and reload.
# and thus not possible to restore the trace parameters after dump and reload.
# this is a known limitation.
# this is a known limitation.
new_type
=
_make_class_traceable
(
type
(
res
),
True
)
res
=
new_type
(
res
)
# re-creating the object
res
=
inject_trace_info
(
res
,
func
,
args
,
kwargs
)
res
=
inject_trace_info
(
res
,
func
,
args
,
kwargs
)
else
:
else
:
raise
TypeError
(
f
'Try to add trace info to
{
res
}
, but the type "
{
type
(
res
)
}
" is unknown. '
raise
TypeError
(
f
'Try to add trace info to
{
res
}
, but the type "
{
type
(
res
)
}
" is unknown. '
...
@@ -425,6 +526,48 @@ def _copy_class_wrapper_attributes(base, wrapper):
...
@@ -425,6 +526,48 @@ def _copy_class_wrapper_attributes(base, wrapper):
wrapper
.
__wrapped__
=
base
wrapper
.
__wrapped__
=
base
class
_unwrap_metaclass
(
type
):
# When a subclass is created, it detects whether the super-class is already annotated with @nni.trace.
# If yes, it gets the ``__wrapped__`` inner class, so that it doesn't inherit SerializableObject twice.
# Note that this doesn't work when metaclass is already defined (such as ABCMeta). We give up in that case.
def
__new__
(
cls
,
name
,
bases
,
dct
):
bases
=
tuple
([
getattr
(
base
,
'__wrapped__'
,
base
)
for
base
in
bases
])
return
super
().
__new__
(
cls
,
name
,
bases
,
dct
)
# Using a customized "bases" breaks default isinstance and issubclass.
# We recover this by overriding the subclass and isinstance behavior, which conerns wrapped class only.
def
__subclasscheck__
(
cls
,
subclass
):
inner_cls
=
getattr
(
cls
,
'__wrapped__'
,
cls
)
return
inner_cls
in
inspect
.
getmro
(
subclass
)
def
__instancecheck__
(
cls
,
instance
):
inner_cls
=
getattr
(
cls
,
'__wrapped__'
,
cls
)
return
inner_cls
in
inspect
.
getmro
(
type
(
instance
))
class
_pickling_object
:
# Need `cloudpickle.load` on the callable because the callable is pickled with cloudpickle.
# Used in `_trace_cls`.
def
__new__
(
cls
,
type_
,
kw_only
,
data
):
type_
=
cloudpickle
.
loads
(
type_
)
# Restore the trace type
type_
=
_trace_cls
(
type_
,
kw_only
)
# restore type
if
'_nni_symbol'
in
data
:
data
[
'_nni_symbol'
]
=
cloudpickle
.
loads
(
data
[
'_nni_symbol'
])
# https://docs.python.org/3/library/pickle.html#pickling-class-instances
obj
=
type_
.
__new__
(
type_
)
if
hasattr
(
obj
,
'__setstate__'
):
obj
.
__setstate__
(
data
)
else
:
obj
.
__dict__
.
update
(
data
)
return
obj
def
_argument_processor
(
arg
):
def
_argument_processor
(
arg
):
# 1) translate
# 1) translate
# handle cases like ValueChoice
# handle cases like ValueChoice
...
@@ -533,7 +676,9 @@ def _import_cls_or_func_from_name(target: str) -> Any:
...
@@ -533,7 +676,9 @@ def _import_cls_or_func_from_name(target: str) -> Any:
def
_strip_trace_type
(
traceable
:
Any
)
->
Any
:
def
_strip_trace_type
(
traceable
:
Any
)
->
Any
:
if
getattr
(
traceable
,
'_traced'
,
False
):
if
getattr
(
traceable
,
'_traced'
,
False
):
return
traceable
.
__wrapped__
# sometimes, ``__wrapped__`` could be unavailable (e.g., with `inject_trace_info`)
# need to have a default value
return
getattr
(
traceable
,
'__wrapped__'
,
traceable
)
return
traceable
return
traceable
...
@@ -598,7 +743,7 @@ def _json_tricks_serializable_object_encode(obj: Any, primitives: bool = False,
...
@@ -598,7 +743,7 @@ def _json_tricks_serializable_object_encode(obj: Any, primitives: bool = False,
# Encodes a serializable object instance to json.
# Encodes a serializable object instance to json.
# do nothing to instance that is not a serializable object and do not use trace
# do nothing to instance that is not a serializable object and do not use trace
if
not
use_trace
or
not
is_traceable
(
obj
):
if
not
(
use_trace
and
hasattr
(
obj
,
'__class__'
)
and
is_traceable
(
type
(
obj
)
))
:
return
obj
return
obj
if
isinstance
(
obj
.
trace_symbol
,
property
):
if
isinstance
(
obj
.
trace_symbol
,
property
):
...
...
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
View file @
21abc280
...
@@ -101,6 +101,7 @@ class _MultiModelSupervisedLearningModule(LightningModule):
...
@@ -101,6 +101,7 @@ class _MultiModelSupervisedLearningModule(LightningModule):
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
@
nni
.
trace
class
MultiModelSupervisedLearningModule
(
_MultiModelSupervisedLearningModule
):
class
MultiModelSupervisedLearningModule
(
_MultiModelSupervisedLearningModule
):
"""
"""
Lightning Module of SupervisedLearning for Cross-Graph Optimization.
Lightning Module of SupervisedLearning for Cross-Graph Optimization.
...
...
nni/retiarii/serializer.py
View file @
21abc280
...
@@ -5,7 +5,7 @@ import inspect
...
@@ -5,7 +5,7 @@ import inspect
import
warnings
import
warnings
from
typing
import
Any
,
TypeVar
,
Union
from
typing
import
Any
,
TypeVar
,
Union
from
nni.common.serializer
import
Traceable
,
is_traceable
,
trace
,
_copy_class_wrapper_attributes
from
nni.common.serializer
import
Traceable
,
is_traceable
,
is_wrapped_with_trace
,
trace
,
_copy_class_wrapper_attributes
from
.utils
import
ModelNamespace
from
.utils
import
ModelNamespace
__all__
=
[
'get_init_parameters_or_fail'
,
'serialize'
,
'serialize_cls'
,
'basic_unit'
,
'model_wrapper'
,
__all__
=
[
'get_init_parameters_or_fail'
,
'serialize'
,
'serialize_cls'
,
'basic_unit'
,
'model_wrapper'
,
...
@@ -64,7 +64,8 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
...
@@ -64,7 +64,8 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
class PrimitiveOp(nn.Module):
class PrimitiveOp(nn.Module):
...
...
"""
"""
_check_wrapped
(
cls
)
if
_check_wrapped
(
cls
,
'basic_unit'
):
return
cls
import
torch.nn
as
nn
import
torch.nn
as
nn
assert
issubclass
(
cls
,
nn
.
Module
),
'When using @basic_unit, the class must be a subclass of nn.Module.'
assert
issubclass
(
cls
,
nn
.
Module
),
'When using @basic_unit, the class must be a subclass of nn.Module.'
...
@@ -72,15 +73,7 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
...
@@ -72,15 +73,7 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
cls
=
trace
(
cls
)
cls
=
trace
(
cls
)
cls
.
_nni_basic_unit
=
basic_unit_tag
cls
.
_nni_basic_unit
=
basic_unit_tag
# HACK: for torch script
_torchscript_patch
(
cls
)
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import
torch
cls
.
_get_nni_attr
=
torch
.
jit
.
ignore
(
cls
.
_get_nni_attr
)
cls
.
trace_symbol
=
torch
.
jit
.
unused
(
cls
.
trace_symbol
)
cls
.
trace_args
=
torch
.
jit
.
unused
(
cls
.
trace_args
)
cls
.
trace_kwargs
=
torch
.
jit
.
unused
(
cls
.
trace_kwargs
)
return
cls
return
cls
...
@@ -103,12 +96,14 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
...
@@ -103,12 +96,14 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed.
Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed.
But in future, we might enforce ``@model_wrapper`` to be required for base model.
But in future, we might enforce ``@model_wrapper`` to be required for base model.
"""
"""
_check_wrapped
(
cls
)
if
_check_wrapped
(
cls
,
'model_wrapper'
):
return
cls
import
torch.nn
as
nn
import
torch.nn
as
nn
assert
issubclass
(
cls
,
nn
.
Module
)
assert
issubclass
(
cls
,
nn
.
Module
)
wrapper
=
trace
(
cls
)
# subclass can still use trace info
wrapper
=
trace
(
cls
,
inheritable
=
True
)
class
reset_wrapper
(
wrapper
):
class
reset_wrapper
(
wrapper
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
...
@@ -116,8 +111,12 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
...
@@ -116,8 +111,12 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
_copy_class_wrapper_attributes
(
wrapper
,
reset_wrapper
)
_copy_class_wrapper_attributes
(
wrapper
,
reset_wrapper
)
reset_wrapper
.
__wrapped__
=
wrapper
.
__wrapped__
reset_wrapper
.
__wrapped__
=
getattr
(
wrapper
,
'
__wrapped__
'
,
wrapper
)
reset_wrapper
.
_nni_model_wrapper
=
True
reset_wrapper
.
_nni_model_wrapper
=
True
reset_wrapper
.
_traced
=
True
_torchscript_patch
(
cls
)
return
reset_wrapper
return
reset_wrapper
...
@@ -133,6 +132,32 @@ def is_model_wrapped(cls_or_instance) -> bool:
...
@@ -133,6 +132,32 @@ def is_model_wrapped(cls_or_instance) -> bool:
return
getattr
(
cls_or_instance
,
'_nni_model_wrapper'
,
False
)
return
getattr
(
cls_or_instance
,
'_nni_model_wrapper'
,
False
)
def
_check_wrapped
(
cls
:
T
)
->
bool
:
def
_check_wrapped
(
cls
:
T
,
rewrap
:
str
)
->
bool
:
if
getattr
(
cls
,
'_traced'
,
False
)
or
getattr
(
cls
,
'_nni_model_wrapper'
,
False
):
wrapped
=
None
raise
TypeError
(
f
'
{
cls
}
is already wrapped with trace wrapper (basic_unit / model_wrapper / trace). Cannot wrap again.'
)
if
is_model_wrapped
(
cls
):
wrapped
=
'model_wrapper'
elif
is_basic_unit
(
cls
):
wrapped
=
'basic_unit'
elif
is_wrapped_with_trace
(
cls
):
wrapped
=
'nni.trace'
if
wrapped
:
if
wrapped
!=
rewrap
:
raise
TypeError
(
f
'
{
cls
}
is already wrapped with
{
wrapped
}
. Cannot rewrap with
{
rewrap
}
.'
)
return
True
return
False
def
_torchscript_patch
(
cls
)
->
None
:
# HACK: for torch script
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import
torch
if
hasattr
(
cls
,
'_get_nni_attr'
):
# could not exist on non-linux
cls
.
_get_nni_attr
=
torch
.
jit
.
ignore
(
cls
.
_get_nni_attr
)
if
hasattr
(
cls
,
'trace_symbol'
):
# these must all exist or all non-exist
cls
.
trace_symbol
=
torch
.
jit
.
unused
(
cls
.
trace_symbol
)
cls
.
trace_args
=
torch
.
jit
.
unused
(
cls
.
trace_args
)
cls
.
trace_kwargs
=
torch
.
jit
.
unused
(
cls
.
trace_kwargs
)
cls
.
trace_copy
=
torch
.
jit
.
ignore
(
cls
.
trace_copy
)
test/ut/retiarii/test_cgo_engine.py
View file @
21abc280
...
@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
...
@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
from
pathlib
import
Path
from
pathlib
import
Path
import
nni
import
nni
import
nni.runtime.platform.test
try
:
try
:
from
nni.common.device
import
GPUDevice
from
nni.common.device
import
GPUDevice
...
...
test/ut/sdk/test_serializer.py
View file @
21abc280
import
math
import
math
import
pickle
import
sys
import
sys
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -27,6 +28,11 @@ class SimpleClass:
...
@@ -27,6 +28,11 @@ class SimpleClass:
self
.
_b
=
b
self
.
_b
=
b
@
nni
.
trace
class
EmptyClass
:
pass
class
UnserializableSimpleClass
:
class
UnserializableSimpleClass
:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
_a
=
1
self
.
_a
=
1
...
@@ -124,7 +130,8 @@ def test_custom_class():
...
@@ -124,7 +130,8 @@ def test_custom_class():
module
=
nni
.
trace
(
Foo
)(
Foo
(
1
),
5
)
module
=
nni
.
trace
(
Foo
)(
Foo
(
1
),
5
)
dumped_module
=
nni
.
dump
(
module
)
dumped_module
=
nni
.
dump
(
module
)
assert
len
(
dumped_module
)
>
200
# should not be too longer if the serialization is correct
module
=
nni
.
load
(
dumped_module
)
assert
module
.
bb
[
0
]
==
module
.
bb
[
999
]
==
6
module
=
nni
.
trace
(
Foo
)(
nni
.
trace
(
Foo
)(
1
),
5
)
module
=
nni
.
trace
(
Foo
)(
nni
.
trace
(
Foo
)(
1
),
5
)
dumped_module
=
nni
.
dump
(
module
)
dumped_module
=
nni
.
dump
(
module
)
...
@@ -193,6 +200,20 @@ def test_dataset():
...
@@ -193,6 +200,20 @@ def test_dataset():
assert
y
.
size
()
==
torch
.
Size
([
10
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
def test_pickle():
    """Traced objects must survive a pickle round-trip, extra attributes included."""
    def roundtrip(obj):
        return pickle.loads(pickle.dumps(obj))

    # A traced class with an empty body must at least be picklable.
    pickle.dumps(EmptyClass())

    restored = roundtrip(SimpleClass(1))
    assert restored._a == 1
    assert restored._b == 1

    # Attributes assigned after construction must also be preserved.
    tagged = SimpleClass(1)
    tagged.xxx = 3
    assert roundtrip(tagged).xxx == 3
@
pytest
.
mark
.
skipif
(
sys
.
platform
!=
'linux'
,
reason
=
'https://github.com/microsoft/nni/issues/4434'
)
@
pytest
.
mark
.
skipif
(
sys
.
platform
!=
'linux'
,
reason
=
'https://github.com/microsoft/nni/issues/4434'
)
def
test_multiprocessing_dataloader
():
def
test_multiprocessing_dataloader
():
# check whether multi-processing works
# check whether multi-processing works
...
@@ -208,6 +229,28 @@ def test_multiprocessing_dataloader():
...
@@ -208,6 +229,28 @@ def test_multiprocessing_dataloader():
assert
y
.
size
()
==
torch
.
Size
([
10
])
assert
y
.
size
()
==
torch
.
Size
([
10
])
def _test_multiprocessing_dataset_worker(dataset):
    """Child-process body: verify what survives crossing the process boundary."""
    if sys.platform != 'linux':
        # On non-linux the loaded object becomes non-traceable due to an
        # implementation limitation, so only the base type can be checked.
        from torch.utils.data import Dataset
        assert isinstance(dataset, Dataset)
    else:
        assert is_traceable(dataset)
def test_multiprocessing_dataset():
    """A traced Dataset must be transferable to a child process intact."""
    import multiprocessing

    from torch.utils.data import Dataset

    traced_dataset = nni.trace(Dataset)()
    worker = multiprocessing.Process(
        target=_test_multiprocessing_dataset_worker,
        args=(traced_dataset,),
    )
    worker.start()
    worker.join()
    # A non-zero exit code means the worker's assertions failed.
    assert worker.exitcode == 0
def test_type():
    """Bare classes serialize to a path-style type record and load back to the class."""
    adam_json = '{"__nni_type__": "path:torch.optim.adam.Adam"}'
    assert nni.dump(torch.optim.Adam) == adam_json
    assert nni.load(adam_json) == torch.optim.Adam
...
@@ -220,10 +263,20 @@ def test_lightning_earlystop():
...
@@ -220,10 +263,20 @@ def test_lightning_earlystop():
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
from
pytorch_lightning.callbacks.early_stopping
import
EarlyStopping
from
pytorch_lightning.callbacks.early_stopping
import
EarlyStopping
trainer
=
pl
.
Trainer
(
callbacks
=
[
nni
.
trace
(
EarlyStopping
)(
monitor
=
"val_loss"
)])
trainer
=
pl
.
Trainer
(
callbacks
=
[
nni
.
trace
(
EarlyStopping
)(
monitor
=
"val_loss"
)])
trainer
=
nni
.
load
(
nni
.
dump
(
trainer
))
pickle_size_limit
=
4096
if
sys
.
platform
==
'linux'
else
32768
trainer
=
nni
.
load
(
nni
.
dump
(
trainer
,
pickle_size_limit
=
pickle_size_limit
))
assert
any
(
isinstance
(
callback
,
EarlyStopping
)
for
callback
in
trainer
.
callbacks
)
assert
any
(
isinstance
(
callback
,
EarlyStopping
)
for
callback
in
trainer
.
callbacks
)
def test_pickle_trainer():
    """A lightning Trainer built via the nni wrapper must round-trip through pickle."""
    import nni.retiarii.evaluator.pytorch.lightning as pl
    from pytorch_lightning import Trainer

    original = pl.Trainer(max_epochs=1)
    restored = pickle.loads(pickle.dumps(original))
    assert isinstance(restored, Trainer)
def
test_generator
():
def
test_generator
():
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torch.optim
as
optim
...
@@ -272,11 +325,31 @@ def test_arguments_kind():
...
@@ -272,11 +325,31 @@ def test_arguments_kind():
assert
lstm
.
trace_kwargs
==
{
'input_size'
:
2
,
'hidden_size'
:
2
}
assert
lstm
.
trace_kwargs
==
{
'input_size'
:
2
,
'hidden_size'
:
2
}
if
__name__
==
'__main__'
:
def test_subclass():
    """Tracing must cooperate with inheritance.

    Only classes directly decorated with ``nni.trace`` record their own
    constructor kwargs; an undecorated subclass of a traced base has
    undefined trace metadata but must still construct and type-check.
    """
    @nni.trace
    class Super:
        def __init__(self, a, b):
            self._a = a
            self._b = b

    class Sub1(Super):
        def __init__(self, c, d):
            super().__init__(3, 4)
            self._c = c
            self._d = d

    @nni.trace
    class Sub2(Super):
        def __init__(self, c, d):
            super().__init__(3, 4)
            self._c = c
            self._d = d

    instance = Sub1(1, 2)
    # There could be trace_kwargs for instance. Behavior is undefined.
    assert instance._a == 3 and instance._c == 1
    assert isinstance(instance, Super)

    instance = Sub2(1, 2)
    # The decorated subclass records exactly its own constructor kwargs.
    assert instance.trace_kwargs == {'c': 1, 'd': 2}
    assert issubclass(type(instance), Super)
    assert isinstance(instance, Super)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment