"vscode:/vscode.git/clone" did not exist on "46f750ea2d3f4426bc35a9dd80fa30f5da754c92"
Unverified commit 443ba8c1 authored by Yuge Zhang, committed by GitHub
Browse files

Serialization infrastructure V2 (#4337)

parent 896c516f
......@@ -114,7 +114,9 @@ CGO Execution
Utilities
---------
.. autofunction:: nni.retiarii.serialize
.. autofunction:: nni.retiarii.basic_unit
.. autofunction:: nni.retiarii.model_wrapper
.. autofunction:: nni.retiarii.fixed_arch
......
......@@ -78,3 +78,9 @@ Utilities
---------
.. autofunction:: nni.utils.merge_parameter
.. autofunction:: nni.trace
.. autofunction:: nni.dump
.. autofunction:: nni.load
......@@ -3,7 +3,7 @@ import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn
import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls
from nni.retiarii import model_wrapper, serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench101Cell
from nni.retiarii.strategy import Random
......@@ -82,7 +82,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target)
@serialize_cls
@nni.trace
class NasBench101TrainingModule(pl.LightningModule):
def __init__(self, max_epochs=108, learning_rate=0.1, weight_decay=1e-4):
super().__init__()
......
......@@ -3,7 +3,7 @@ import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn
import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls
from nni.retiarii import model_wrapper, serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench201Cell
from nni.retiarii.strategy import Random
......@@ -71,7 +71,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target)
@serialize_cls
@nni.trace
class NasBench201TrainingModule(pl.LightningModule):
def __init__(self, max_epochs=200, learning_rate=0.1, weight_decay=5e-4):
super().__init__()
......
......@@ -9,7 +9,7 @@ except ModuleNotFoundError:
from .runtime.log import init_logger
init_logger()
from .common.serializer import *
from .common.serializer import trace, dump, load
from .runtime.env_vars import dispatcher_env_vars
from .utils import ClassArgsValidator
......
......@@ -7,12 +7,12 @@ bohb_advisor.py
import sys
import math
import logging
import json_tricks
from schema import Schema, Optional
import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
from ConfigSpace.read_and_write import pcs_new
import nni
from nni import ClassArgsValidator
from nni.runtime.protocol import CommandType, send
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
......@@ -428,7 +428,7 @@ class BOHB(MsgDispatcherBase):
'parameter_source': 'algorithm',
'parameters': ''
}
send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret))
send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None
assert self.generated_hyper_configs
params = self.generated_hyper_configs.pop(0)
......@@ -459,7 +459,7 @@ class BOHB(MsgDispatcherBase):
"""
ret = self._get_one_trial_job()
if ret is not None:
send(CommandType.NewTrialJob, json_tricks.dumps(ret))
send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1
def handle_update_search_space(self, data):
......@@ -536,7 +536,7 @@ class BOHB(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
logger.debug('Tuner handle trial end, result is %s', data)
hyper_params = json_tricks.loads(data['hyper_params'])
hyper_params = nni.load(data['hyper_params'])
self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']]
......@@ -551,7 +551,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = one_unsatisfied['parameter_index']
# update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[ret['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
send(CommandType.SendTrialJobParameter, nni.dump(ret))
for _ in range(self.credit):
self._request_one_trial_job()
......@@ -584,7 +584,7 @@ class BOHB(MsgDispatcherBase):
"""
logger.debug('handle report metric data = %s', data)
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
data['value'] = nni.load(data['value'])
if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled()
assert data['trial_job_id'] is not None
......@@ -599,7 +599,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = data['parameter_index']
# update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
send(CommandType.SendTrialJobParameter, nni.dump(ret))
else:
assert 'value' in data
value = extract_scalar_reward(data['value'])
......@@ -655,7 +655,7 @@ class BOHB(MsgDispatcherBase):
data doesn't have required key 'parameter' and 'value'
"""
for entry in data:
entry['value'] = json_tricks.loads(entry['value'])
entry['value'] = nni.load(entry['value'])
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
......
......@@ -10,10 +10,10 @@ import logging
import math
import sys
import json_tricks
import numpy as np
from schema import Schema, Optional
import nni
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.runtime.common import multi_phase_enabled
......@@ -336,7 +336,7 @@ class Hyperband(MsgDispatcherBase):
def _request_one_trial_job(self):
ret = self._get_one_trial_job()
if ret is not None:
send(CommandType.NewTrialJob, json_tricks.dumps(ret))
send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1
def _get_one_trial_job(self):
......@@ -365,7 +365,7 @@ class Hyperband(MsgDispatcherBase):
'parameter_source': 'algorithm',
'parameters': ''
}
send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret))
send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None
assert self.generated_hyper_configs
......@@ -408,7 +408,7 @@ class Hyperband(MsgDispatcherBase):
event: the job's state
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
hyper_params = json_tricks.loads(data['hyper_params'])
hyper_params = nni.load(data['hyper_params'])
self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']]
......@@ -426,7 +426,7 @@ class Hyperband(MsgDispatcherBase):
Data type not supported
"""
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
data['value'] = nni.load(data['value'])
# multiphase? need to check
if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled()
......@@ -440,7 +440,7 @@ class Hyperband(MsgDispatcherBase):
if data['parameter_index'] is not None:
ret['parameter_index'] = data['parameter_index']
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret))
send(CommandType.SendTrialJobParameter, nni.dump(ret))
else:
value = extract_scalar_reward(data['value'])
bracket_id, i, _ = data['parameter_id'].split('_')
......
from .serializer import trace, dump, load, is_traceable
import abc
import copy
import collections.abc
import base64
import functools
import inspect
import numbers
import types
import warnings
from io import IOBase
from typing import Any, Union, Dict, Optional, List, TypeVar
import json_tricks # use json_tricks as serializer backend
import cloudpickle # use cloudpickle as backend for unserializable types and instances
__all__ = ['trace', 'dump', 'load', 'SerializableObject']
__all__ = ['trace', 'dump', 'load', 'Translatable', 'Traceable', 'is_traceable']
T = TypeVar('T')
class SerializableObject:
class Traceable(abc.ABC):
    """
    Interface of a traceable object.

    A traceable object records how it was created: a symbol (a class or a function)
    together with the positional and keyword arguments it was invoked with.
    ``trace_copy`` produces a copy suitable for further mutation, and the recorded
    symbol/args/kwargs enable serialization (e.g., via ``nni.dump``).
    """

    @abc.abstractmethod
    def trace_copy(self) -> 'Traceable':
        """
        Perform a shallow copy.
        NOTE: NONE of the attributes will be preserved.
        This is the one that should be used when you want to "mutate" a serializable object.
        """
        ...

    @property
    @abc.abstractmethod
    def trace_symbol(self) -> Any:
        """
        Symbol object. Could be a class or a function.
        ``get_hybrid_cls_or_func_name`` and ``import_cls_or_func_from_hybrid_name`` is a pair to
        convert the symbol into a string and convert the string back to symbol.
        """
        ...

    @property
    @abc.abstractmethod
    def trace_args(self) -> List[Any]:
        """
        List of positional arguments passed to symbol. Usually empty if ``kw_only`` is true,
        in which case all the positional arguments are converted into keyword arguments.
        """
        ...

    @property
    @abc.abstractmethod
    def trace_kwargs(self) -> Dict[str, Any]:
        """
        Dict of keyword arguments passed to symbol.
        """
        ...
class Translatable(abc.ABC):
    """
    Inherit this class and implement ``_translate`` when the inner class needs a
    different parameter from the wrapper class in its init function.
    """

    @abc.abstractmethod
    def _translate(self) -> Any:
        """Return the value that should actually be handed to the inner class/function."""
        ...

    @staticmethod
    def _translate_argument(d: Any) -> Any:
        """Translate ``d`` if it is a ``Translatable``; otherwise return it untouched."""
        return d._translate() if isinstance(d, Translatable) else d
def is_traceable(obj: Any) -> bool:
    """
    Check whether an object is a traceable instance (not type).

    An object counts as traceable when it exposes the full ``Traceable``
    interface (``trace_copy``, ``trace_symbol``, ``trace_args``,
    ``trace_kwargs``) and is an instance rather than a class.
    """
    if inspect.isclass(obj):
        # classes are excluded even if they expose the interface
        return False
    required_attrs = ('trace_copy', 'trace_symbol', 'trace_args', 'trace_kwargs')
    return all(hasattr(obj, attr) for attr in required_attrs)
class SerializableObject(Traceable):
"""
Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
"""
def __init__(self, symbol: T, args: List[Any], kwargs: Dict[str, Any],
_self_contained: bool = False):
def __init__(self, symbol: T, args: List[Any], kwargs: Dict[str, Any], call_super: bool = False):
# use dict to avoid conflicts with user's getattr and setattr
self.__dict__['_nni_symbol'] = symbol
self.__dict__['_nni_args'] = args
self.__dict__['_nni_kwargs'] = kwargs
self.__dict__['_nni_call_super'] = call_super
self.__dict__['_nni_self_contained'] = _self_contained
if _self_contained:
# this is for internal usage only.
# kwargs is used to init the full object in the same object as this one, for simpler implementation.
super().__init__(*self._recursive_init(args), **self._recursive_init(kwargs))
def get(self) -> Any:
"""
Get the original object.
"""
if self._get_nni_attr('self_contained'):
return self
if '_nni_cache' not in self.__dict__:
self.__dict__['_nni_cache'] = self._get_nni_attr('symbol')(
*self._recursive_init(self._get_nni_attr('args')),
**self._recursive_init(self._get_nni_attr('kwargs'))
if call_super:
# call super means that the serializable object is by itself an object of the target class
super().__init__(
*[_argument_processor(arg) for arg in args],
**{kw: _argument_processor(arg) for kw, arg in kwargs.items()}
)
return self.__dict__['_nni_cache']
def copy(self) -> Union[T, 'SerializableObject']:
"""
Perform a shallow copy. Will throw away the self-contain property for classes (refer to implementation).
This is the one that should be used when you want to "mutate" a serializable object.
"""
def trace_copy(self) -> Union[T, 'SerializableObject']:
return SerializableObject(
self._get_nni_attr('symbol'),
self._get_nni_attr('args'),
self._get_nni_attr('kwargs')
self.trace_symbol,
[copy.copy(arg) for arg in self.trace_args],
{k: copy.copy(v) for k, v in self.trace_kwargs.items()},
)
def __json_encode__(self):
ret = {'__symbol__': _get_hybrid_cls_or_func_name(self._get_nni_attr('symbol'))}
if self._get_nni_attr('args'):
ret['__args__'] = self._get_nni_attr('args')
ret['__kwargs__'] = self._get_nni_attr('kwargs')
return ret
@property
def trace_symbol(self) -> Any:
return self._get_nni_attr('symbol')
@trace_symbol.setter
def trace_symbol(self, symbol: Any) -> None:
# for mutation purposes
self.__dict__['_nni_symbol'] = symbol
def _get_nni_attr(self, name):
@property
def trace_args(self) -> List[Any]:
return self._get_nni_attr('args')
@trace_args.setter
def trace_args(self, args: List[Any]):
self.__dict__['_nni_args'] = args
@property
def trace_kwargs(self) -> Dict[str, Any]:
return self._get_nni_attr('kwargs')
@trace_kwargs.setter
def trace_kwargs(self, kwargs: Dict[str, Any]):
self.__dict__['_nni_kwargs'] = kwargs
def _get_nni_attr(self, name: str) -> Any:
return self.__dict__['_nni_' + name]
def __repr__(self):
if self._get_nni_attr('self_contained'):
return repr(self)
if '_nni_cache' in self.__dict__:
return repr(self._get_nni_attr('cache'))
if self._get_nni_attr('call_super'):
return super().__repr__()
return 'SerializableObject(' + \
', '.join(['type=' + self._get_nni_attr('symbol').__name__] +
[repr(d) for d in self._get_nni_attr('args')] +
[k + '=' + repr(v) for k, v in self._get_nni_attr('kwargs').items()]) + \
')'
@staticmethod
def _recursive_init(d):
# auto-call get() to prevent type-converting in downstreaming functions
if isinstance(d, dict):
return {k: v.get() if isinstance(v, SerializableObject) else v for k, v in d.items()}
def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> Any:
# If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
def getter_factory(x):
return lambda self: self.__dict__['_nni_' + x]
def setter_factory(x):
def setter(self, val):
self.__dict__['_nni_' + x] = val
return setter
def trace_copy(self):
return SerializableObject(
self.trace_symbol,
[copy.copy(arg) for arg in self.trace_args],
{k: copy.copy(v) for k, v in self.trace_kwargs.items()},
)
attributes = {
'trace_symbol': property(getter_factory('symbol'), setter_factory('symbol')),
'trace_args': property(getter_factory('args'), setter_factory('args')),
'trace_kwargs': property(getter_factory('kwargs'), setter_factory('kwargs')),
'trace_copy': trace_copy
}
if hasattr(obj, '__class__') and hasattr(obj, '__dict__'):
for name, method in attributes.items():
setattr(obj.__class__, name, method)
else:
return [v.get() if isinstance(v, SerializableObject) else v for v in d]
wrapper = type('wrapper', (Traceable, type(obj)), attributes)
obj = wrapper(obj) # pylint: disable=abstract-class-instantiated
# make obj complying with the interface of traceable, though we cannot change its base class
obj.__dict__.update(_nni_symbol=symbol, _nni_args=args, _nni_kwargs=kwargs)
return obj
def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, SerializableObject]:
def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]:
"""
Annotate a function or a class if you want to preserve where it comes from.
This is usually used in the following scenarios:
......@@ -98,16 +205,16 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab
When a class/function is annotated, all the instances/calls will return an object as it normally will.
Although the object might act like a normal object, it's actually a different object with NNI-specific properties.
To get the original object, you should use ``obj.get()`` to retrieve. The retrieved object can be used
like the original one, but there are still subtle differences in implementation.
Note that when using the result from a trace in another trace-able function/class, ``.get()`` is automatically
called, so that you don't have to worry about type-converting.
One exception is that if your function returns None, it will return an empty SerializableObject instead,
which should raise your attention when you want to check whether the None ``is None``.
Also it records extra information about where this object comes from. That's why it's called "trace".
When parameters of functions are received, it is first stored, and then a shallow copy will be passed to inner function.
This is to prevent mutable objects gets modified in the inner function.
When the function finished execution, we also record extra information about where this object comes from.
That's why it's called "trace".
When ``nni.dump`` is called, that information will be used by default.
If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspect the argument
If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspecting the argument
list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
Example:
......@@ -120,10 +227,18 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab
"""
def wrap(cls_or_func):
# already annotated, do nothing
if getattr(cls_or_func, '_traced', False):
return cls_or_func
if isinstance(cls_or_func, type):
return _trace_cls(cls_or_func, kw_only)
cls_or_func = _trace_cls(cls_or_func, kw_only)
elif _is_function(cls_or_func):
cls_or_func = _trace_func(cls_or_func, kw_only)
else:
return _trace_func(cls_or_func, kw_only)
raise TypeError(f'{cls_or_func} of type {type(cls_or_func)} is not supported to be traced. '
'File an issue at https://github.com/microsoft/nni/issues if you believe this is a mistake.')
cls_or_func._traced = True
return cls_or_func
# if we're being called as @trace()
if cls_or_func is None:
......@@ -133,8 +248,8 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab
return wrap(cls_or_func)
def dump(obj: Any, fp: Optional[Any] = None, use_trace: bool = True, pickle_size_limit: int = 4096,
**json_tricks_kwargs) -> Union[str, bytes]:
def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_size_limit: int = 4096,
allow_nan: bool = True, **json_tricks_kwargs) -> Union[str, bytes]:
"""
Convert a nested data structure to a json string. Save to file if fp is specified.
Use json-tricks as main backend. For unhandled cases in json-tricks, use cloudpickle.
......@@ -143,10 +258,14 @@ def dump(obj: Any, fp: Optional[Any] = None, use_trace: bool = True, pickle_size
Parameters
----------
obj : any
The object to dump.
fp : file handler or path
File to write to. Keep it none if you want to dump a string.
pickle_size_limit : int
This is set to avoid too long serialization result. Set to -1 to disable size check.
allow_nan : bool
Whether to allow nan to be serialized. Different from default value in json-tricks, our default value is true.
json_tricks_kwargs : dict
Other keyword arguments passed to json tricks (backend), e.g., indent=2.
......@@ -171,19 +290,32 @@ def dump(obj: Any, fp: Optional[Any] = None, use_trace: bool = True, pickle_size
functools.partial(_json_tricks_any_object_encode, pickle_size_limit=pickle_size_limit),
]
json_tricks_kwargs['allow_nan'] = allow_nan
if fp is not None:
return json_tricks.dump(obj, fp, obj_encoders=encoders, **json_tricks_kwargs)
else:
return json_tricks.dumps(obj, obj_encoders=encoders, **json_tricks_kwargs)
def load(string: str = None, fp: Optional[Any] = None, **json_tricks_kwargs) -> Any:
def load(string: Optional[str] = None, *, fp: Optional[Any] = None, ignore_comments: bool = True, **json_tricks_kwargs) -> Any:
"""
Load the string or from file, and convert it to a complex data structure.
At least one of string or fp has to be not none.
Parameters
----------
string : str
JSON string to parse. Can be set to none if fp is used.
fp : str
File path to load JSON from. Can be set to none if string is used.
ignore_comments : bool
Remove comments (starting with ``#`` or ``//``). Default is true.
Returns
-------
any
The loaded object.
"""
assert string is not None or fp is not None
# see encoders for explanation
......@@ -201,7 +333,12 @@ def load(string: str = None, fp: Optional[Any] = None, **json_tricks_kwargs) ->
_json_tricks_any_object_decode
]
# to bypass a deprecation warning in json-tricks
json_tricks_kwargs['ignore_comments'] = ignore_comments
if string is not None:
if isinstance(string, IOBase):
raise TypeError(f'Expect a string, found a {string}. If you intend to use a file, use `nni.load(fp=file)`')
return json_tricks.loads(string, obj_pairs_hooks=hooks, **json_tricks_kwargs)
else:
return json_tricks.load(fp, obj_pairs_hooks=hooks, **json_tricks_kwargs)
......@@ -214,11 +351,57 @@ def _trace_cls(base, kw_only):
class wrapper(SerializableObject, base):
def __init__(self, *args, **kwargs):
# store a copy of initial parameters
args, kwargs = _get_arguments_as_dict(base.__init__, args, kwargs, kw_only)
args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True)
# calling serializable object init to initialize the full object
super().__init__(symbol=base, args=args, kwargs=kwargs, _self_contained=True)
super().__init__(symbol=base, args=args, kwargs=kwargs, call_super=True)
_copy_class_wrapper_attributes(base, wrapper)
return wrapper
def _trace_func(func, kw_only):
    """
    Wrap ``func`` so that each call records the (formulated) arguments and
    attaches trace info to the returned object. How the info is attached
    depends on the return value's type (see the branches below).
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # similar to class, store parameters here
        args, kwargs = _formulate_arguments(func, args, kwargs, kw_only)

        # it's not clear whether this wrapper can handle all the types in python
        # There are many cases here: https://docs.python.org/3/reference/datamodel.html
        # but it looks that we have handled most commonly used cases
        res = func(
            *[_argument_processor(arg) for arg in args],
            **{kw: _argument_processor(arg) for kw, arg in kwargs.items()}
        )

        if res is None:
            # don't call super, makes no sense.
            # an empty serializable object is "none". Don't check it though.
            res = SerializableObject(func, args, kwargs, call_super=False)
        elif hasattr(res, '__class__') and hasattr(res, '__dict__'):
            # is a class, inject interface directly
            # need to be done before primitive types because there could be inheritance here.
            res = inject_trace_info(res, func, args, kwargs)
        elif isinstance(res, (collections.abc.Callable, types.ModuleType, IOBase)):
            # functions/modules/streams cannot carry the injected attributes
            raise TypeError(f'Try to add trace info to {res}, but functions and modules are not supported.')
        elif isinstance(res, (numbers.Number, collections.abc.Sequence, collections.abc.Set, collections.abc.Mapping)):
            # handle primitive types like int, str, set, dict, tuple
            # NOTE: simple types including none, bool, int, float, list, tuple, dict
            # will be directly captured by python json encoder
            # and thus not possible to restore the trace parameters after dump and reload.
            # this is a known limitation.
            res = inject_trace_info(res, func, args, kwargs)
        else:
            raise TypeError(f'Try to add trace info to {res}, but the type "{type(res)}" is unknown. '
                            'Please file an issue at https://github.com/microsoft/nni/issues')

        return res

    return wrapper
def _copy_class_wrapper_attributes(base, wrapper):
_MISSING = '_missing'
for k in functools.WRAPPER_ASSIGNMENTS:
# assign magic attributes like __module__, __qualname__, __doc__
......@@ -229,25 +412,29 @@ def _trace_cls(base, kw_only):
except AttributeError:
pass
return wrapper
wrapper.__wrapped__ = base
def _trace_func(func, kw_only):
@functools.wraps
def wrapper(*args, **kwargs):
# similar to class, store parameters here
args, kwargs = _get_arguments_as_dict(func, args, kwargs, kw_only)
return SerializableObject(func, args, kwargs)
return wrapper
def _argument_processor(arg):
    """
    Prepare a recorded argument before it is forwarded to the inner class/function.
    """
    # Step 1: translate.
    # Handles cases like ValueChoice, where the recorded argument is meant to be
    # different from what the inner object actually receives.
    processed = Translatable._translate_argument(arg)

    # Step 2: shallow-copy mutable containers, so that the stored parameters
    # cannot be mutated by the inner class.
    # An example: https://github.com/microsoft/nni/issues/4329
    mutable_container_types = (
        collections.abc.MutableMapping,
        collections.abc.MutableSequence,
        collections.abc.MutableSet,
    )
    if isinstance(processed, mutable_container_types):
        processed = copy.copy(processed)

    return processed
def _get_arguments_as_dict(func, args, kwargs, kw_only):
def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
# This is to formulate the arguments and make them well-formed.
if kw_only:
# get arguments passed to a function, and save it as a dict
argname_list = list(inspect.signature(func).parameters.keys())[1:]
argname_list = list(inspect.signature(func).parameters.keys())
if is_class_init:
argname_list = argname_list[1:]
full_args = {}
full_args.update(kwargs)
# match arguments with given arguments
# args should be longer than given list, because args can be used in a kwargs way
......@@ -255,9 +442,18 @@ def _get_arguments_as_dict(func, args, kwargs, kw_only):
for argname, value in zip(argname_list, args):
full_args[argname] = value
# use kwargs to override
full_args.update(kwargs)
args, kwargs = [], full_args
return args, kwargs
return list(args), kwargs
def _is_function(obj: Any) -> bool:
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return isinstance(obj, (types.FunctionType, types.BuiltinFunctionType, types.MethodType,
types.BuiltinMethodType))
def _import_cls_or_func_from_name(target: str) -> Any:
......@@ -268,6 +464,12 @@ def _import_cls_or_func_from_name(target: str) -> Any:
return getattr(module, identifier)
def _strip_trace_type(traceable: Any) -> Any:
if getattr(traceable, '_traced', False):
return traceable.__wrapped__
return traceable
def _get_cls_or_func_name(cls_or_func: Any) -> str:
module_name = cls_or_func.__module__
if module_name == '__main__':
......@@ -276,7 +478,8 @@ def _get_cls_or_func_name(cls_or_func: Any) -> str:
try:
imported = _import_cls_or_func_from_name(full_name)
if imported != cls_or_func:
# ignores the differences in trace
if _strip_trace_type(imported) != _strip_trace_type(cls_or_func):
raise ImportError(f'Imported {imported} is not same as expected. The function might be dynamically created.')
except ImportError:
raise ImportError(f'Import {cls_or_func.__name__} from "{module_name}" failed.')
......@@ -284,12 +487,12 @@ def _get_cls_or_func_name(cls_or_func: Any) -> str:
return full_name
def _get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096) -> str:
def get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096) -> str:
try:
name = _get_cls_or_func_name(cls_or_func)
# import success, use a path format
return 'path:' + name
except ImportError:
except (ImportError, AttributeError):
b = cloudpickle.dumps(cls_or_func)
if len(b) > pickle_size_limit:
raise ValueError(f'Pickle too large when trying to dump {cls_or_func}. '
......@@ -298,7 +501,7 @@ def _get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096
return 'bytes:' + base64.b64encode(b).decode()
def _import_cls_or_func_from_hybrid_name(s: str) -> Any:
def import_cls_or_func_from_hybrid_name(s: str) -> Any:
if s.startswith('bytes:'):
b = base64.b64decode(s.split(':', 1)[-1])
return cloudpickle.loads(b)
......@@ -308,40 +511,47 @@ def _import_cls_or_func_from_hybrid_name(s: str) -> Any:
def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> str:
if not isinstance(cls_or_func, type) and not callable(cls_or_func):
if not isinstance(cls_or_func, type) and not _is_function(cls_or_func):
# not a function or class, continue
return cls_or_func
return {
'__nni_type__': _get_hybrid_cls_or_func_name(cls_or_func, pickle_size_limit)
'__nni_type__': get_hybrid_cls_or_func_name(cls_or_func, pickle_size_limit)
}
def _json_tricks_func_or_cls_decode(s: Dict[str, Any]) -> Any:
if isinstance(s, dict) and '__nni_type__' in s:
s = s['__nni_type__']
return _import_cls_or_func_from_hybrid_name(s)
return import_cls_or_func_from_hybrid_name(s)
return s
def _json_tricks_serializable_object_encode(obj: Any, primitives: bool = False, use_trace: bool = True) -> Dict[str, Any]:
# Encodes a serializable object instance to json.
# If primitives, the representation is simplified and cannot be recovered!
# do nothing to instance that is not a serializable object and do not use trace
if not use_trace or not isinstance(obj, SerializableObject):
if not use_trace or not is_traceable(obj):
return obj
return obj.__json_encode__()
if isinstance(obj.trace_symbol, property):
# commonly made mistake when users forget to call the traced function/class.
warnings.warn(f'The symbol of {obj} is found to be a property. Did you forget to create the instance with ``xx(...)``?')
ret = {'__symbol__': get_hybrid_cls_or_func_name(obj.trace_symbol)}
if obj.trace_args:
ret['__args__'] = obj.trace_args
if obj.trace_kwargs:
ret['__kwargs__'] = obj.trace_kwargs
return ret
def _json_tricks_serializable_object_decode(obj: Dict[str, Any]) -> Any:
if isinstance(obj, dict) and '__symbol__' in obj and '__kwargs__' in obj:
return SerializableObject(
_import_cls_or_func_from_hybrid_name(obj['__symbol__']),
getattr(obj, '__args__', []),
obj['__kwargs__']
)
if isinstance(obj, dict) and '__symbol__' in obj:
symbol = import_cls_or_func_from_hybrid_name(obj['__symbol__'])
args = obj.get('__args__', [])
kwargs = obj.get('__kwargs__', {})
return trace(symbol)(*args, **kwargs)
return obj
......@@ -353,8 +563,9 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si
if hasattr(obj, '__class__') and (hasattr(obj, '__dict__') or hasattr(obj, '__slots__')):
b = cloudpickle.dumps(obj)
if len(b) > pickle_size_limit:
raise ValueError(f'Pickle too large when trying to dump {obj}. '
'Please try to raise pickle_size_limit if you insist.')
raise ValueError(f'Pickle too large when trying to dump {obj}. This might be caused by classes that are '
'not decorated by @nni.trace. Another option is to force bytes pickling and '
'try to raise pickle_size_limit.')
# use base64 to dump a bytes array
return {
'__nni_obj__': base64.b64encode(b).decode()
......
......@@ -6,11 +6,11 @@ from subprocess import Popen
import time
from typing import Optional, Union, List, overload, Any
import json_tricks
import colorama
import psutil
import nni.runtime.log
from nni.common import dump
from .config import ExperimentConfig, AlgorithmConfig
from .data import TrialJob, TrialMetricData, TrialResult
......@@ -439,7 +439,7 @@ class Experiment:
value: dict
New search_space.
"""
value = json_tricks.dumps(value)
value = dump(value)
self._update_experiment_profile('searchSpace', value)
def update_max_trial_number(self, value: int):
......
......@@ -6,4 +6,4 @@ from .graph import *
from .execution import *
from .fixed import fixed_arch
from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper
from .serializer import basic_unit, model_wrapper, serialize, serialize_cls
......@@ -637,7 +637,7 @@ class GraphConverter:
original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
m_attrs = get_init_parameters_or_fail(module)
elif getattr(module, '_stop_parsing', False):
elif getattr(module, '_nni_basic_unit', False):
# this module is marked as serialize, won't continue to parse
m_attrs = get_init_parameters_or_fail(module)
if m_attrs is not None:
......
......@@ -10,7 +10,7 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import Trainin
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from ....serializer import serialize_cls
import nni
class BypassPlugin(TrainingTypePlugin):
......@@ -126,7 +126,7 @@ def get_accelerator_connector(
)
@serialize_cls
@nni.trace
class BypassAccelerator(Accelerator):
def __init__(self, precision_plugin=None, device="cpu", **trainer_kwargs):
if precision_plugin is None:
......
......@@ -14,10 +14,9 @@ import nni
from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer
from ....serializer import serialize_cls
@serialize_cls
@nni.trace
class _MultiModelSupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
n_models: int = 0,
......@@ -126,7 +125,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
@serialize_cls
@nni.trace
class _ClassificationModule(MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
......@@ -174,7 +173,7 @@ class Classification(Lightning):
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@serialize_cls
@nni.trace
class _RegressionModule(MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import pytorch_lightning as pl
from ....serializer import serialize_cls
import nni
from .accelerator import BypassAccelerator
@serialize_cls
@nni.trace
class Trainer(pl.Trainer):
"""
Trainer for cross-graph optimization.
......
......@@ -10,17 +10,17 @@ import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import DataLoader
import torch.utils.data as torch_data
import nni
from nni.common.serializer import is_traceable
try:
from .cgo import trainer as cgo_trainer
cgo_import_failed = False
except ImportError:
cgo_import_failed = True
from ...graph import Evaluator
from ...serializer import serialize_cls
from nni.retiarii.graph import Evaluator
__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
......@@ -40,9 +40,10 @@ class LightningModule(pl.LightningModule):
self.model = model
Trainer = serialize_cls(pl.Trainer)
DataLoader = serialize_cls(DataLoader)
Trainer = nni.trace(pl.Trainer)
DataLoader = nni.trace(torch_data.DataLoader)
@nni.trace
class Lightning(Evaluator):
"""
Delegate the whole training to PyTorch Lightning.
......@@ -74,9 +75,10 @@ class Lightning(Evaluator):
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
if cgo_import_failed:
assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}'
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
else:
assert isinstance(trainer, Trainer) or isinstance(trainer, cgo_trainer.Trainer), \
# this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
......@@ -135,7 +137,7 @@ def _check_dataloader(dataloader):
return True
if isinstance(dataloader, list):
return all([_check_dataloader(d) for d in dataloader])
return isinstance(dataloader, DataLoader)
return isinstance(dataloader, torch_data.DataLoader) and is_traceable(dataloader)
### The following are some commonly used Lightning modules ###
......@@ -219,7 +221,7 @@ class _AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target)
@serialize_cls
@nni.trace
class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
......@@ -272,7 +274,7 @@ class Classification(Lightning):
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@serialize_cls
@nni.trace
class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
......
......@@ -200,7 +200,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# replace the module with a new instance whose n_models is set
# n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
new_module_init_params = model.evaluator.module._init_parameters.copy()
new_module_init_params = model.evaluator.module.trace_kwargs.copy()
# MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
new_module_init_params['n_models'] = len(multi_model)
......
......@@ -4,13 +4,13 @@
import logging
from typing import Any, Callable
import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType
from .graph import MetricData
from .integration_api import register_advisor
from .serializer import json_dumps, json_loads
_logger = logging.getLogger(__name__)
......@@ -121,7 +121,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
'placement_constraint': placement_constraint
}
_logger.debug('New trial sent: %s', new_trial)
send(CommandType.NewTrialJob, json_dumps(new_trial))
send(CommandType.NewTrialJob, nni.dump(new_trial))
if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count
......@@ -140,7 +140,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
def handle_trial_end(self, data):
_logger.debug('Trial end: %s', data)
self.trial_end_callback(json_loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data):
......@@ -156,7 +156,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
@staticmethod
def _process_value(value) -> Any: # hopefully a float
value = json_loads(value)
value = nni.load(value)
if isinstance(value, dict):
if 'default' in value:
return value['default']
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
from typing import NewType, Any
import nni
from .serializer import json_loads
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
......@@ -41,7 +38,6 @@ def receive_trial_parameters() -> dict:
Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
"""
params = nni.get_next_parameter()
params = json_loads(json.dumps(params))
return params
......
......@@ -8,8 +8,9 @@ from typing import Any, List, Union, Dict, Optional
import torch
import torch.nn as nn
from ...serializer import Translatable, basic_unit
from ...utils import NoContextError
from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import NoContextError
from .utils import generate_new_label, get_fixed_value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment