Unverified Commit 443ba8c1 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Serialization infrastructure V2 (#4337)

parent 896c516f
...@@ -114,7 +114,9 @@ CGO Execution ...@@ -114,7 +114,9 @@ CGO Execution
Utilities Utilities
--------- ---------
.. autofunction:: nni.retiarii.serialize .. autofunction:: nni.retiarii.basic_unit
.. autofunction:: nni.retiarii.model_wrapper
.. autofunction:: nni.retiarii.fixed_arch .. autofunction:: nni.retiarii.fixed_arch
......
...@@ -78,3 +78,9 @@ Utilities ...@@ -78,3 +78,9 @@ Utilities
--------- ---------
.. autofunction:: nni.utils.merge_parameter .. autofunction:: nni.utils.merge_parameter
.. autofunction:: nni.trace
.. autofunction:: nni.dump
.. autofunction:: nni.load
...@@ -3,7 +3,7 @@ import nni ...@@ -3,7 +3,7 @@ import nni
import nni.retiarii.evaluator.pytorch.lightning as pl import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn import torch.nn as nn
import torchmetrics import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls from nni.retiarii import model_wrapper, serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench101Cell from nni.retiarii.nn.pytorch import NasBench101Cell
from nni.retiarii.strategy import Random from nni.retiarii.strategy import Random
...@@ -82,7 +82,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy): ...@@ -82,7 +82,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target) return super().update(nn.functional.softmax(pred), target)
@serialize_cls @nni.trace
class NasBench101TrainingModule(pl.LightningModule): class NasBench101TrainingModule(pl.LightningModule):
def __init__(self, max_epochs=108, learning_rate=0.1, weight_decay=1e-4): def __init__(self, max_epochs=108, learning_rate=0.1, weight_decay=1e-4):
super().__init__() super().__init__()
......
...@@ -3,7 +3,7 @@ import nni ...@@ -3,7 +3,7 @@ import nni
import nni.retiarii.evaluator.pytorch.lightning as pl import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn import torch.nn as nn
import torchmetrics import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls from nni.retiarii import model_wrapper, serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench201Cell from nni.retiarii.nn.pytorch import NasBench201Cell
from nni.retiarii.strategy import Random from nni.retiarii.strategy import Random
...@@ -71,7 +71,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy): ...@@ -71,7 +71,7 @@ class AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target) return super().update(nn.functional.softmax(pred), target)
@serialize_cls @nni.trace
class NasBench201TrainingModule(pl.LightningModule): class NasBench201TrainingModule(pl.LightningModule):
def __init__(self, max_epochs=200, learning_rate=0.1, weight_decay=5e-4): def __init__(self, max_epochs=200, learning_rate=0.1, weight_decay=5e-4):
super().__init__() super().__init__()
......
...@@ -9,7 +9,7 @@ except ModuleNotFoundError: ...@@ -9,7 +9,7 @@ except ModuleNotFoundError:
from .runtime.log import init_logger from .runtime.log import init_logger
init_logger() init_logger()
from .common.serializer import * from .common.serializer import trace, dump, load
from .runtime.env_vars import dispatcher_env_vars from .runtime.env_vars import dispatcher_env_vars
from .utils import ClassArgsValidator from .utils import ClassArgsValidator
......
...@@ -7,12 +7,12 @@ bohb_advisor.py ...@@ -7,12 +7,12 @@ bohb_advisor.py
import sys import sys
import math import math
import logging import logging
import json_tricks
from schema import Schema, Optional from schema import Schema, Optional
import ConfigSpace as CS import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH import ConfigSpace.hyperparameters as CSH
from ConfigSpace.read_and_write import pcs_new from ConfigSpace.read_and_write import pcs_new
import nni
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.runtime.protocol import CommandType, send from nni.runtime.protocol import CommandType, send
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
...@@ -428,7 +428,7 @@ class BOHB(MsgDispatcherBase): ...@@ -428,7 +428,7 @@ class BOHB(MsgDispatcherBase):
'parameter_source': 'algorithm', 'parameter_source': 'algorithm',
'parameters': '' 'parameters': ''
} }
send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret)) send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None return None
assert self.generated_hyper_configs assert self.generated_hyper_configs
params = self.generated_hyper_configs.pop(0) params = self.generated_hyper_configs.pop(0)
...@@ -459,7 +459,7 @@ class BOHB(MsgDispatcherBase): ...@@ -459,7 +459,7 @@ class BOHB(MsgDispatcherBase):
""" """
ret = self._get_one_trial_job() ret = self._get_one_trial_job()
if ret is not None: if ret is not None:
send(CommandType.NewTrialJob, json_tricks.dumps(ret)) send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1 self.credit -= 1
def handle_update_search_space(self, data): def handle_update_search_space(self, data):
...@@ -536,7 +536,7 @@ class BOHB(MsgDispatcherBase): ...@@ -536,7 +536,7 @@ class BOHB(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner hyper_params: the hyperparameters (a string) generated and returned by tuner
""" """
logger.debug('Tuner handle trial end, result is %s', data) logger.debug('Tuner handle trial end, result is %s', data)
hyper_params = json_tricks.loads(data['hyper_params']) hyper_params = nni.load(data['hyper_params'])
self._handle_trial_end(hyper_params['parameter_id']) self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map: if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']] del self.job_id_para_id_map[data['trial_job_id']]
...@@ -551,7 +551,7 @@ class BOHB(MsgDispatcherBase): ...@@ -551,7 +551,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = one_unsatisfied['parameter_index'] ret['parameter_index'] = one_unsatisfied['parameter_index']
# update parameter_id in self.job_id_para_id_map # update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[ret['trial_job_id']] = ret['parameter_id'] self.job_id_para_id_map[ret['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret)) send(CommandType.SendTrialJobParameter, nni.dump(ret))
for _ in range(self.credit): for _ in range(self.credit):
self._request_one_trial_job() self._request_one_trial_job()
...@@ -584,7 +584,7 @@ class BOHB(MsgDispatcherBase): ...@@ -584,7 +584,7 @@ class BOHB(MsgDispatcherBase):
""" """
logger.debug('handle report metric data = %s', data) logger.debug('handle report metric data = %s', data)
if 'value' in data: if 'value' in data:
data['value'] = json_tricks.loads(data['value']) data['value'] = nni.load(data['value'])
if data['type'] == MetricType.REQUEST_PARAMETER: if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled() assert multi_phase_enabled()
assert data['trial_job_id'] is not None assert data['trial_job_id'] is not None
...@@ -599,7 +599,7 @@ class BOHB(MsgDispatcherBase): ...@@ -599,7 +599,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = data['parameter_index'] ret['parameter_index'] = data['parameter_index']
# update parameter_id in self.job_id_para_id_map # update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id'] self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret)) send(CommandType.SendTrialJobParameter, nni.dump(ret))
else: else:
assert 'value' in data assert 'value' in data
value = extract_scalar_reward(data['value']) value = extract_scalar_reward(data['value'])
...@@ -655,7 +655,7 @@ class BOHB(MsgDispatcherBase): ...@@ -655,7 +655,7 @@ class BOHB(MsgDispatcherBase):
data doesn't have required key 'parameter' and 'value' data doesn't have required key 'parameter' and 'value'
""" """
for entry in data: for entry in data:
entry['value'] = json_tricks.loads(entry['value']) entry['value'] = nni.load(entry['value'])
_completed_num = 0 _completed_num = 0
for trial_info in data: for trial_info in data:
logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data)) logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
......
...@@ -10,10 +10,10 @@ import logging ...@@ -10,10 +10,10 @@ import logging
import math import math
import sys import sys
import json_tricks
import numpy as np import numpy as np
from schema import Schema, Optional from schema import Schema, Optional
import nni
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space from nni.common.hpo_utils import validate_search_space
from nni.runtime.common import multi_phase_enabled from nni.runtime.common import multi_phase_enabled
...@@ -336,7 +336,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -336,7 +336,7 @@ class Hyperband(MsgDispatcherBase):
def _request_one_trial_job(self): def _request_one_trial_job(self):
ret = self._get_one_trial_job() ret = self._get_one_trial_job()
if ret is not None: if ret is not None:
send(CommandType.NewTrialJob, json_tricks.dumps(ret)) send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1 self.credit -= 1
def _get_one_trial_job(self): def _get_one_trial_job(self):
...@@ -365,7 +365,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -365,7 +365,7 @@ class Hyperband(MsgDispatcherBase):
'parameter_source': 'algorithm', 'parameter_source': 'algorithm',
'parameters': '' 'parameters': ''
} }
send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret)) send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None return None
assert self.generated_hyper_configs assert self.generated_hyper_configs
...@@ -408,7 +408,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -408,7 +408,7 @@ class Hyperband(MsgDispatcherBase):
event: the job's state event: the job's state
hyper_params: the hyperparameters (a string) generated and returned by tuner hyper_params: the hyperparameters (a string) generated and returned by tuner
""" """
hyper_params = json_tricks.loads(data['hyper_params']) hyper_params = nni.load(data['hyper_params'])
self._handle_trial_end(hyper_params['parameter_id']) self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map: if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']] del self.job_id_para_id_map[data['trial_job_id']]
...@@ -426,7 +426,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -426,7 +426,7 @@ class Hyperband(MsgDispatcherBase):
Data type not supported Data type not supported
""" """
if 'value' in data: if 'value' in data:
data['value'] = json_tricks.loads(data['value']) data['value'] = nni.load(data['value'])
# multiphase? need to check # multiphase? need to check
if data['type'] == MetricType.REQUEST_PARAMETER: if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled() assert multi_phase_enabled()
...@@ -440,7 +440,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -440,7 +440,7 @@ class Hyperband(MsgDispatcherBase):
if data['parameter_index'] is not None: if data['parameter_index'] is not None:
ret['parameter_index'] = data['parameter_index'] ret['parameter_index'] = data['parameter_index']
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id'] self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, json_tricks.dumps(ret)) send(CommandType.SendTrialJobParameter, nni.dump(ret))
else: else:
value = extract_scalar_reward(data['value']) value = extract_scalar_reward(data['value'])
bracket_id, i, _ = data['parameter_id'].split('_') bracket_id, i, _ = data['parameter_id'].split('_')
......
from .serializer import trace, dump, load, is_traceable
import abc
import copy
import collections.abc
import base64 import base64
import functools import functools
import inspect import inspect
import numbers
import types
import warnings
from io import IOBase
from typing import Any, Union, Dict, Optional, List, TypeVar from typing import Any, Union, Dict, Optional, List, TypeVar
import json_tricks # use json_tricks as serializer backend import json_tricks # use json_tricks as serializer backend
import cloudpickle # use cloudpickle as backend for unserializable types and instances import cloudpickle # use cloudpickle as backend for unserializable types and instances
__all__ = ['trace', 'dump', 'load', 'SerializableObject'] __all__ = ['trace', 'dump', 'load', 'Translatable', 'Traceable', 'is_traceable']
T = TypeVar('T') T = TypeVar('T')
class SerializableObject: class Traceable(abc.ABC):
"""
A traceable object has copy and dict. Copy and mutate are used to copy the object for further mutations.
Dict returns a TraceDictType to enable serialization.
"""
@abc.abstractmethod
def trace_copy(self) -> 'Traceable':
"""
Perform a shallow copy.
NOTE: NONE of the attributes will be preserved.
This is the one that should be used when you want to "mutate" a serializable object.
"""
...
@property
@abc.abstractmethod
def trace_symbol(self) -> Any:
"""
Symbol object. Could be a class or a function.
``get_hybrid_cls_or_func_name`` and ``import_cls_or_func_from_hybrid_name`` are a pair that
converts the symbol into a string and converts the string back into the symbol.
"""
...
@property
@abc.abstractmethod
def trace_args(self) -> List[Any]:
"""
List of positional arguments passed to symbol. Usually empty if ``kw_only`` is true,
in which case all the positional arguments are converted into keyword arguments.
"""
...
@property
@abc.abstractmethod
def trace_kwargs(self) -> Dict[str, Any]:
"""
Dict of keyword arguments.
"""
...
class Translatable(abc.ABC):
    """
    Inherit this class and implement ``_translate`` when the inner class needs a different
    parameter from the wrapper class in its init function.

    Use ``Translatable._translate_argument(value)`` to convert a value that may or may not
    be translatable: translatable values are replaced by the result of their ``_translate()``,
    anything else is passed through untouched.
    """

    @abc.abstractmethod
    def _translate(self) -> Any:
        """
        Return the object that should be handed over in place of ``self``.

        Subclasses must override this method; the class cannot be instantiated otherwise.
        """
        pass

    @staticmethod
    def _translate_argument(d: Any) -> Any:
        """
        Translate ``d`` if it is a ``Translatable`` instance.

        Parameters
        ----------
        d : any
            The candidate argument.

        Returns
        -------
        any
            ``d._translate()`` when ``d`` is a ``Translatable``, otherwise ``d`` itself (same object).
        """
        if isinstance(d, Translatable):
            return d._translate()
        return d
def is_traceable(obj: Any) -> bool:
    """
    Check whether an object is a traceable instance (not type).

    An object qualifies when it exposes all of ``trace_copy``, ``trace_symbol``,
    ``trace_args`` and ``trace_kwargs`` as attributes, and is not itself a class.
    """
    required_attrs = ('trace_copy', 'trace_symbol', 'trace_args', 'trace_kwargs')
    # attribute checks run first and stop at the first missing one;
    # the class check only happens when the full interface is present
    exposes_interface = all(hasattr(obj, attr) for attr in required_attrs)
    return exposes_interface and not inspect.isclass(obj)
class SerializableObject(Traceable):
""" """
Serializable object is a wrapper of existing python objects, that supports dump and load easily. Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``. Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
""" """
def __init__(self, symbol: T, args: List[Any], kwargs: Dict[str, Any], def __init__(self, symbol: T, args: List[Any], kwargs: Dict[str, Any], call_super: bool = False):
_self_contained: bool = False):
# use dict to avoid conflicts with user's getattr and setattr # use dict to avoid conflicts with user's getattr and setattr
self.__dict__['_nni_symbol'] = symbol self.__dict__['_nni_symbol'] = symbol
self.__dict__['_nni_args'] = args self.__dict__['_nni_args'] = args
self.__dict__['_nni_kwargs'] = kwargs self.__dict__['_nni_kwargs'] = kwargs
self.__dict__['_nni_call_super'] = call_super
self.__dict__['_nni_self_contained'] = _self_contained if call_super:
# call super means that the serializable object is by itself an object of the target class
if _self_contained: super().__init__(
# this is for internal usage only. *[_argument_processor(arg) for arg in args],
# kwargs is used to init the full object in the same object as this one, for simpler implementation. **{kw: _argument_processor(arg) for kw, arg in kwargs.items()}
super().__init__(*self._recursive_init(args), **self._recursive_init(kwargs))
def get(self) -> Any:
"""
Get the original object.
"""
if self._get_nni_attr('self_contained'):
return self
if '_nni_cache' not in self.__dict__:
self.__dict__['_nni_cache'] = self._get_nni_attr('symbol')(
*self._recursive_init(self._get_nni_attr('args')),
**self._recursive_init(self._get_nni_attr('kwargs'))
) )
return self.__dict__['_nni_cache']
def copy(self) -> Union[T, 'SerializableObject']: def trace_copy(self) -> Union[T, 'SerializableObject']:
"""
Perform a shallow copy. Will throw away the self-contain property for classes (refer to implementation).
This is the one that should be used when you want to "mutate" a serializable object.
"""
return SerializableObject( return SerializableObject(
self._get_nni_attr('symbol'), self.trace_symbol,
self._get_nni_attr('args'), [copy.copy(arg) for arg in self.trace_args],
self._get_nni_attr('kwargs') {k: copy.copy(v) for k, v in self.trace_kwargs.items()},
) )
def __json_encode__(self): @property
ret = {'__symbol__': _get_hybrid_cls_or_func_name(self._get_nni_attr('symbol'))} def trace_symbol(self) -> Any:
if self._get_nni_attr('args'): return self._get_nni_attr('symbol')
ret['__args__'] = self._get_nni_attr('args')
ret['__kwargs__'] = self._get_nni_attr('kwargs') @trace_symbol.setter
return ret def trace_symbol(self, symbol: Any) -> None:
# for mutation purposes
self.__dict__['_nni_symbol'] = symbol
def _get_nni_attr(self, name): @property
def trace_args(self) -> List[Any]:
return self._get_nni_attr('args')
@trace_args.setter
def trace_args(self, args: List[Any]):
self.__dict__['_nni_args'] = args
@property
def trace_kwargs(self) -> Dict[str, Any]:
return self._get_nni_attr('kwargs')
@trace_kwargs.setter
def trace_kwargs(self, kwargs: Dict[str, Any]):
self.__dict__['_nni_kwargs'] = kwargs
def _get_nni_attr(self, name: str) -> Any:
return self.__dict__['_nni_' + name] return self.__dict__['_nni_' + name]
def __repr__(self): def __repr__(self):
if self._get_nni_attr('self_contained'): if self._get_nni_attr('call_super'):
return repr(self) return super().__repr__()
if '_nni_cache' in self.__dict__:
return repr(self._get_nni_attr('cache'))
return 'SerializableObject(' + \ return 'SerializableObject(' + \
', '.join(['type=' + self._get_nni_attr('symbol').__name__] + ', '.join(['type=' + self._get_nni_attr('symbol').__name__] +
[repr(d) for d in self._get_nni_attr('args')] + [repr(d) for d in self._get_nni_attr('args')] +
[k + '=' + repr(v) for k, v in self._get_nni_attr('kwargs').items()]) + \ [k + '=' + repr(v) for k, v in self._get_nni_attr('kwargs').items()]) + \
')' ')'
@staticmethod
def _recursive_init(d): def inject_trace_info(obj: Any, symbol: T, args: List[Any], kwargs: Dict[str, Any]) -> Any:
# auto-call get() to prevent type-converting in downstreaming functions # If an object is already created, this can be a fix so that the necessary info are re-injected into the object.
if isinstance(d, dict):
return {k: v.get() if isinstance(v, SerializableObject) else v for k, v in d.items()} def getter_factory(x):
return lambda self: self.__dict__['_nni_' + x]
def setter_factory(x):
def setter(self, val):
self.__dict__['_nni_' + x] = val
return setter
def trace_copy(self):
return SerializableObject(
self.trace_symbol,
[copy.copy(arg) for arg in self.trace_args],
{k: copy.copy(v) for k, v in self.trace_kwargs.items()},
)
attributes = {
'trace_symbol': property(getter_factory('symbol'), setter_factory('symbol')),
'trace_args': property(getter_factory('args'), setter_factory('args')),
'trace_kwargs': property(getter_factory('kwargs'), setter_factory('kwargs')),
'trace_copy': trace_copy
}
if hasattr(obj, '__class__') and hasattr(obj, '__dict__'):
for name, method in attributes.items():
setattr(obj.__class__, name, method)
else: else:
return [v.get() if isinstance(v, SerializableObject) else v for v in d] wrapper = type('wrapper', (Traceable, type(obj)), attributes)
obj = wrapper(obj) # pylint: disable=abstract-class-instantiated
# make obj complying with the interface of traceable, though we cannot change its base class
obj.__dict__.update(_nni_symbol=symbol, _nni_args=args, _nni_kwargs=kwargs)
return obj
def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, SerializableObject]:
def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]:
""" """
Annotate a function or a class if you want to preserve where it comes from. Annotate a function or a class if you want to preserve where it comes from.
This is usually used in the following scenarios: This is usually used in the following scenarios:
...@@ -98,16 +205,16 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab ...@@ -98,16 +205,16 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab
When a class/function is annotated, all the instances/calls will return a object as it normally will. When a class/function is annotated, all the instances/calls will return a object as it normally will.
Although the object might act like a normal object, it's actually a different object with NNI-specific properties. Although the object might act like a normal object, it's actually a different object with NNI-specific properties.
To get the original object, you should use ``obj.get()`` to retrieve. The retrieved object can be used One exception is that if your function returns None, it will return an empty SerializableObject instead,
like the original one, but there are still subtle differences in implementation. which should raise your attention when you want to check whether the None ``is None``.
Note that when using the result from a trace in another trace-able function/class, ``.get()`` is automatically
called, so that you don't have to worry about type-converting.
Also it records extra information about where this object comes from. That's why it's called "trace". When parameters of functions are received, it is first stored, and then a shallow copy will be passed to inner function.
This is to prevent mutable objects from being modified in the inner function.
When the function finished execution, we also record extra information about where this object comes from.
That's why it's called "trace".
When call ``nni.dump``, that information will be used, by default. When call ``nni.dump``, that information will be used, by default.
If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspect the argument If ``kw_only`` is true, try to convert all parameters into kwargs type. This is done by inspecting the argument
list and types. This can be useful to extract semantics, but can be tricky in some corner cases. list and types. This can be useful to extract semantics, but can be tricky in some corner cases.
Example: Example:
...@@ -120,10 +227,18 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab ...@@ -120,10 +227,18 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab
""" """
def wrap(cls_or_func): def wrap(cls_or_func):
# already annotated, do nothing
if getattr(cls_or_func, '_traced', False):
return cls_or_func
if isinstance(cls_or_func, type): if isinstance(cls_or_func, type):
return _trace_cls(cls_or_func, kw_only) cls_or_func = _trace_cls(cls_or_func, kw_only)
elif _is_function(cls_or_func):
cls_or_func = _trace_func(cls_or_func, kw_only)
else: else:
return _trace_func(cls_or_func, kw_only) raise TypeError(f'{cls_or_func} of type {type(cls_or_func)} is not supported to be traced. '
'File an issue at https://github.com/microsoft/nni/issues if you believe this is a mistake.')
cls_or_func._traced = True
return cls_or_func
# if we're being called as @trace() # if we're being called as @trace()
if cls_or_func is None: if cls_or_func is None:
...@@ -133,8 +248,8 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab ...@@ -133,8 +248,8 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Serializab
return wrap(cls_or_func) return wrap(cls_or_func)
def dump(obj: Any, fp: Optional[Any] = None, use_trace: bool = True, pickle_size_limit: int = 4096, def dump(obj: Any, fp: Optional[Any] = None, *, use_trace: bool = True, pickle_size_limit: int = 4096,
**json_tricks_kwargs) -> Union[str, bytes]: allow_nan: bool = True, **json_tricks_kwargs) -> Union[str, bytes]:
""" """
Convert a nested data structure to a json string. Save to file if fp is specified. Convert a nested data structure to a json string. Save to file if fp is specified.
Use json-tricks as main backend. For unhandled cases in json-tricks, use cloudpickle. Use json-tricks as main backend. For unhandled cases in json-tricks, use cloudpickle.
...@@ -143,10 +258,14 @@ def dump(obj: Any, fp: Optional[Any] = None, use_trace: bool = True, pickle_size ...@@ -143,10 +258,14 @@ def dump(obj: Any, fp: Optional[Any] = None, use_trace: bool = True, pickle_size
Parameters Parameters
---------- ----------
obj : any
The object to dump.
fp : file handler or path fp : file handler or path
File to write to. Keep it none if you want to dump a string. File to write to. Keep it none if you want to dump a string.
pickle_size_limit : int pickle_size_limit : int
This is set to avoid too long serialization result. Set to -1 to disable size check. This is set to avoid too long serialization result. Set to -1 to disable size check.
allow_nan : bool
Whether to allow nan to be serialized. Different from default value in json-tricks, our default value is true.
json_tricks_kwargs : dict json_tricks_kwargs : dict
Other keyword arguments passed to json tricks (backend), e.g., indent=2. Other keyword arguments passed to json tricks (backend), e.g., indent=2.
...@@ -171,19 +290,32 @@ def dump(obj: Any, fp: Optional[Any] = None, use_trace: bool = True, pickle_size ...@@ -171,19 +290,32 @@ def dump(obj: Any, fp: Optional[Any] = None, use_trace: bool = True, pickle_size
functools.partial(_json_tricks_any_object_encode, pickle_size_limit=pickle_size_limit), functools.partial(_json_tricks_any_object_encode, pickle_size_limit=pickle_size_limit),
] ]
json_tricks_kwargs['allow_nan'] = allow_nan
if fp is not None: if fp is not None:
return json_tricks.dump(obj, fp, obj_encoders=encoders, **json_tricks_kwargs) return json_tricks.dump(obj, fp, obj_encoders=encoders, **json_tricks_kwargs)
else: else:
return json_tricks.dumps(obj, obj_encoders=encoders, **json_tricks_kwargs) return json_tricks.dumps(obj, obj_encoders=encoders, **json_tricks_kwargs)
def load(string: str = None, fp: Optional[Any] = None, **json_tricks_kwargs) -> Any: def load(string: Optional[str] = None, *, fp: Optional[Any] = None, ignore_comments: bool = True, **json_tricks_kwargs) -> Any:
""" """
Load the string or from file, and convert it to a complex data structure. Load the string or from file, and convert it to a complex data structure.
At least one of string or fp has to be not none. At least one of string or fp has to be not none.
Parameters Parameters
---------- ----------
string : str
JSON string to parse. Can be set to none if fp is used.
fp : str
File path to load JSON from. Can be set to none if string is used.
ignore_comments : bool
Remove comments (starting with ``#`` or ``//``). Default is true.
Returns
-------
any
The loaded object.
""" """
assert string is not None or fp is not None assert string is not None or fp is not None
# see encoders for explanation # see encoders for explanation
...@@ -201,7 +333,12 @@ def load(string: str = None, fp: Optional[Any] = None, **json_tricks_kwargs) -> ...@@ -201,7 +333,12 @@ def load(string: str = None, fp: Optional[Any] = None, **json_tricks_kwargs) ->
_json_tricks_any_object_decode _json_tricks_any_object_decode
] ]
# to bypass a deprecation warning in json-tricks
json_tricks_kwargs['ignore_comments'] = ignore_comments
if string is not None: if string is not None:
if isinstance(string, IOBase):
raise TypeError(f'Expect a string, found a {string}. If you intend to use a file, use `nni.load(fp=file)`')
return json_tricks.loads(string, obj_pairs_hooks=hooks, **json_tricks_kwargs) return json_tricks.loads(string, obj_pairs_hooks=hooks, **json_tricks_kwargs)
else: else:
return json_tricks.load(fp, obj_pairs_hooks=hooks, **json_tricks_kwargs) return json_tricks.load(fp, obj_pairs_hooks=hooks, **json_tricks_kwargs)
...@@ -214,11 +351,57 @@ def _trace_cls(base, kw_only): ...@@ -214,11 +351,57 @@ def _trace_cls(base, kw_only):
class wrapper(SerializableObject, base): class wrapper(SerializableObject, base):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# store a copy of initial parameters # store a copy of initial parameters
args, kwargs = _get_arguments_as_dict(base.__init__, args, kwargs, kw_only) args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True)
# calling serializable object init to initialize the full object # calling serializable object init to initialize the full object
super().__init__(symbol=base, args=args, kwargs=kwargs, _self_contained=True) super().__init__(symbol=base, args=args, kwargs=kwargs, call_super=True)
_copy_class_wrapper_attributes(base, wrapper)
return wrapper
def _trace_func(func, kw_only):
    """
    Wrap ``func`` so that every call records the function and its arguments on the result.

    The arguments are first normalized by ``_formulate_arguments`` (honoring ``kw_only``) and stored;
    the values actually passed to ``func`` are each run through ``_argument_processor``.
    The returned wrapper gives back the function's result with trace info attached,
    or raises ``TypeError`` when the result is of a kind that cannot carry trace info.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # similar to class, store parameters here
        args, kwargs = _formulate_arguments(func, args, kwargs, kw_only)

        # the recorded arguments stay as they are; the inner function receives processed ones
        call_args = [_argument_processor(arg) for arg in args]
        call_kwargs = {kw: _argument_processor(arg) for kw, arg in kwargs.items()}

        # it's not clear whether this wrapper can handle all the types in python
        # There are many cases here: https://docs.python.org/3/reference/datamodel.html
        # but it looks that we have handled most commonly used cases
        res = func(*call_args, **call_kwargs)

        if res is None:
            # don't call super, makes no sense.
            # an empty serializable object is "none". Don't check it though.
            return SerializableObject(func, args, kwargs, call_super=False)

        if hasattr(res, '__class__') and hasattr(res, '__dict__'):
            # is a class, inject interface directly
            # need to be done before primitive types because there could be inheritance here.
            return inject_trace_info(res, func, args, kwargs)

        if isinstance(res, (collections.abc.Callable, types.ModuleType, IOBase)):
            raise TypeError(f'Try to add trace info to {res}, but functions and modules are not supported.')

        if isinstance(res, (numbers.Number, collections.abc.Sequence, collections.abc.Set, collections.abc.Mapping)):
            # handle primitive types like int, str, set, dict, tuple
            # NOTE: simple types including none, bool, int, float, list, tuple, dict
            # will be directly captured by python json encoder
            # and thus not possible to restore the trace parameters after dump and reload.
            # this is a known limitation.
            return inject_trace_info(res, func, args, kwargs)

        raise TypeError(f'Try to add trace info to {res}, but the type "{type(res)}" is unknown. '
                        'Please file an issue at https://github.com/microsoft/nni/issues')

    return wrapper
def _copy_class_wrapper_attributes(base, wrapper):
_MISSING = '_missing' _MISSING = '_missing'
for k in functools.WRAPPER_ASSIGNMENTS: for k in functools.WRAPPER_ASSIGNMENTS:
# assign magic attributes like __module__, __qualname__, __doc__ # assign magic attributes like __module__, __qualname__, __doc__
...@@ -229,25 +412,29 @@ def _trace_cls(base, kw_only): ...@@ -229,25 +412,29 @@ def _trace_cls(base, kw_only):
except AttributeError: except AttributeError:
pass pass
return wrapper wrapper.__wrapped__ = base
def _trace_func(func, kw_only): def _argument_processor(arg):
@functools.wraps # 1) translate
def wrapper(*args, **kwargs): # handle cases like ValueChoice
# similar to class, store parameters here # This is needed because sometimes the recorded arguments are meant to be different from what the inner object receives.
args, kwargs = _get_arguments_as_dict(func, args, kwargs, kw_only) arg = Translatable._translate_argument(arg)
return SerializableObject(func, args, kwargs) # 2) prevent the stored parameters to be mutated by inner class.
# an example: https://github.com/microsoft/nni/issues/4329
return wrapper if isinstance(arg, (collections.abc.MutableMapping, collections.abc.MutableSequence, collections.abc.MutableSet)):
arg = copy.copy(arg)
return arg
def _get_arguments_as_dict(func, args, kwargs, kw_only): def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
# This is to formulate the arguments and make them well-formed.
if kw_only: if kw_only:
# get arguments passed to a function, and save it as a dict # get arguments passed to a function, and save it as a dict
argname_list = list(inspect.signature(func).parameters.keys())[1:] argname_list = list(inspect.signature(func).parameters.keys())
if is_class_init:
argname_list = argname_list[1:]
full_args = {} full_args = {}
full_args.update(kwargs)
# match arguments with given arguments # match arguments with given arguments
# args should be longer than given list, because args can be used in a kwargs way # args should be longer than given list, because args can be used in a kwargs way
...@@ -255,9 +442,18 @@ def _get_arguments_as_dict(func, args, kwargs, kw_only): ...@@ -255,9 +442,18 @@ def _get_arguments_as_dict(func, args, kwargs, kw_only):
for argname, value in zip(argname_list, args): for argname, value in zip(argname_list, args):
full_args[argname] = value full_args[argname] = value
# use kwargs to override
full_args.update(kwargs)
args, kwargs = [], full_args args, kwargs = [], full_args
return args, kwargs return list(args), kwargs
def _is_function(obj: Any) -> bool:
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return isinstance(obj, (types.FunctionType, types.BuiltinFunctionType, types.MethodType,
types.BuiltinMethodType))
def _import_cls_or_func_from_name(target: str) -> Any: def _import_cls_or_func_from_name(target: str) -> Any:
...@@ -268,6 +464,12 @@ def _import_cls_or_func_from_name(target: str) -> Any: ...@@ -268,6 +464,12 @@ def _import_cls_or_func_from_name(target: str) -> Any:
return getattr(module, identifier) return getattr(module, identifier)
def _strip_trace_type(traceable: Any) -> Any:
if getattr(traceable, '_traced', False):
return traceable.__wrapped__
return traceable
def _get_cls_or_func_name(cls_or_func: Any) -> str: def _get_cls_or_func_name(cls_or_func: Any) -> str:
module_name = cls_or_func.__module__ module_name = cls_or_func.__module__
if module_name == '__main__': if module_name == '__main__':
...@@ -276,7 +478,8 @@ def _get_cls_or_func_name(cls_or_func: Any) -> str: ...@@ -276,7 +478,8 @@ def _get_cls_or_func_name(cls_or_func: Any) -> str:
try: try:
imported = _import_cls_or_func_from_name(full_name) imported = _import_cls_or_func_from_name(full_name)
if imported != cls_or_func: # ignores the differences in trace
if _strip_trace_type(imported) != _strip_trace_type(cls_or_func):
raise ImportError(f'Imported {imported} is not same as expected. The function might be dynamically created.') raise ImportError(f'Imported {imported} is not same as expected. The function might be dynamically created.')
except ImportError: except ImportError:
raise ImportError(f'Import {cls_or_func.__name__} from "{module_name}" failed.') raise ImportError(f'Import {cls_or_func.__name__} from "{module_name}" failed.')
...@@ -284,12 +487,12 @@ def _get_cls_or_func_name(cls_or_func: Any) -> str: ...@@ -284,12 +487,12 @@ def _get_cls_or_func_name(cls_or_func: Any) -> str:
return full_name return full_name
def _get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096) -> str: def get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096) -> str:
try: try:
name = _get_cls_or_func_name(cls_or_func) name = _get_cls_or_func_name(cls_or_func)
# import success, use a path format # import success, use a path format
return 'path:' + name return 'path:' + name
except ImportError: except (ImportError, AttributeError):
b = cloudpickle.dumps(cls_or_func) b = cloudpickle.dumps(cls_or_func)
if len(b) > pickle_size_limit: if len(b) > pickle_size_limit:
raise ValueError(f'Pickle too large when trying to dump {cls_or_func}. ' raise ValueError(f'Pickle too large when trying to dump {cls_or_func}. '
...@@ -298,7 +501,7 @@ def _get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096 ...@@ -298,7 +501,7 @@ def _get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096
return 'bytes:' + base64.b64encode(b).decode() return 'bytes:' + base64.b64encode(b).decode()
def _import_cls_or_func_from_hybrid_name(s: str) -> Any: def import_cls_or_func_from_hybrid_name(s: str) -> Any:
if s.startswith('bytes:'): if s.startswith('bytes:'):
b = base64.b64decode(s.split(':', 1)[-1]) b = base64.b64decode(s.split(':', 1)[-1])
return cloudpickle.loads(b) return cloudpickle.loads(b)
...@@ -308,40 +511,47 @@ def _import_cls_or_func_from_hybrid_name(s: str) -> Any: ...@@ -308,40 +511,47 @@ def _import_cls_or_func_from_hybrid_name(s: str) -> Any:
def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> str: def _json_tricks_func_or_cls_encode(cls_or_func: Any, primitives: bool = False, pickle_size_limit: int = 4096) -> str:
if not isinstance(cls_or_func, type) and not callable(cls_or_func): if not isinstance(cls_or_func, type) and not _is_function(cls_or_func):
# not a function or class, continue # not a function or class, continue
return cls_or_func return cls_or_func
return { return {
'__nni_type__': _get_hybrid_cls_or_func_name(cls_or_func, pickle_size_limit) '__nni_type__': get_hybrid_cls_or_func_name(cls_or_func, pickle_size_limit)
} }
def _json_tricks_func_or_cls_decode(s: Dict[str, Any]) -> Any: def _json_tricks_func_or_cls_decode(s: Dict[str, Any]) -> Any:
if isinstance(s, dict) and '__nni_type__' in s: if isinstance(s, dict) and '__nni_type__' in s:
s = s['__nni_type__'] s = s['__nni_type__']
return _import_cls_or_func_from_hybrid_name(s) return import_cls_or_func_from_hybrid_name(s)
return s return s
def _json_tricks_serializable_object_encode(obj: Any, primitives: bool = False, use_trace: bool = True) -> Dict[str, Any]: def _json_tricks_serializable_object_encode(obj: Any, primitives: bool = False, use_trace: bool = True) -> Dict[str, Any]:
# Encodes a serializable object instance to json. # Encodes a serializable object instance to json.
# If primitives, the representation is simplified and cannot be recovered!
# do nothing to instance that is not a serializable object and do not use trace # do nothing to instance that is not a serializable object and do not use trace
if not use_trace or not isinstance(obj, SerializableObject): if not use_trace or not is_traceable(obj):
return obj return obj
return obj.__json_encode__() if isinstance(obj.trace_symbol, property):
# commonly made mistake when users forget to call the traced function/class.
warnings.warn(f'The symbol of {obj} is found to be a property. Did you forget to create the instance with ``xx(...)``?')
ret = {'__symbol__': get_hybrid_cls_or_func_name(obj.trace_symbol)}
if obj.trace_args:
ret['__args__'] = obj.trace_args
if obj.trace_kwargs:
ret['__kwargs__'] = obj.trace_kwargs
return ret
def _json_tricks_serializable_object_decode(obj: Dict[str, Any]) -> Any: def _json_tricks_serializable_object_decode(obj: Dict[str, Any]) -> Any:
if isinstance(obj, dict) and '__symbol__' in obj and '__kwargs__' in obj: if isinstance(obj, dict) and '__symbol__' in obj:
return SerializableObject( symbol = import_cls_or_func_from_hybrid_name(obj['__symbol__'])
_import_cls_or_func_from_hybrid_name(obj['__symbol__']), args = obj.get('__args__', [])
getattr(obj, '__args__', []), kwargs = obj.get('__kwargs__', {})
obj['__kwargs__'] return trace(symbol)(*args, **kwargs)
)
return obj return obj
...@@ -353,8 +563,9 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si ...@@ -353,8 +563,9 @@ def _json_tricks_any_object_encode(obj: Any, primitives: bool = False, pickle_si
if hasattr(obj, '__class__') and (hasattr(obj, '__dict__') or hasattr(obj, '__slots__')): if hasattr(obj, '__class__') and (hasattr(obj, '__dict__') or hasattr(obj, '__slots__')):
b = cloudpickle.dumps(obj) b = cloudpickle.dumps(obj)
if len(b) > pickle_size_limit: if len(b) > pickle_size_limit:
raise ValueError(f'Pickle too large when trying to dump {obj}. ' raise ValueError(f'Pickle too large when trying to dump {obj}. This might be caused by classes that are '
'Please try to raise pickle_size_limit if you insist.') 'not decorated by @nni.trace. Another option is to force bytes pickling and '
'try to raise pickle_size_limit.')
# use base64 to dump a bytes array # use base64 to dump a bytes array
return { return {
'__nni_obj__': base64.b64encode(b).decode() '__nni_obj__': base64.b64encode(b).decode()
......
...@@ -6,11 +6,11 @@ from subprocess import Popen ...@@ -6,11 +6,11 @@ from subprocess import Popen
import time import time
from typing import Optional, Union, List, overload, Any from typing import Optional, Union, List, overload, Any
import json_tricks
import colorama import colorama
import psutil import psutil
import nni.runtime.log import nni.runtime.log
from nni.common import dump
from .config import ExperimentConfig, AlgorithmConfig from .config import ExperimentConfig, AlgorithmConfig
from .data import TrialJob, TrialMetricData, TrialResult from .data import TrialJob, TrialMetricData, TrialResult
...@@ -439,7 +439,7 @@ class Experiment: ...@@ -439,7 +439,7 @@ class Experiment:
value: dict value: dict
New search_space. New search_space.
""" """
value = json_tricks.dumps(value) value = dump(value)
self._update_experiment_profile('searchSpace', value) self._update_experiment_profile('searchSpace', value)
def update_max_trial_number(self, value: int): def update_max_trial_number(self, value: int):
......
...@@ -6,4 +6,4 @@ from .graph import * ...@@ -6,4 +6,4 @@ from .graph import *
from .execution import * from .execution import *
from .fixed import fixed_arch from .fixed import fixed_arch
from .mutator import * from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper from .serializer import basic_unit, model_wrapper, serialize, serialize_cls
...@@ -637,7 +637,7 @@ class GraphConverter: ...@@ -637,7 +637,7 @@ class GraphConverter:
original_type_name not in MODULE_EXCEPT_LIST: original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph # this is a basic module from pytorch, no need to parse its graph
m_attrs = get_init_parameters_or_fail(module) m_attrs = get_init_parameters_or_fail(module)
elif getattr(module, '_stop_parsing', False): elif getattr(module, '_nni_basic_unit', False):
# this module is marked as serialize, won't continue to parse # this module is marked as serialize, won't continue to parse
m_attrs = get_init_parameters_or_fail(module) m_attrs = get_init_parameters_or_fail(module)
if m_attrs is not None: if m_attrs is not None:
......
...@@ -10,7 +10,7 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import Trainin ...@@ -10,7 +10,7 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import Trainin
from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from ....serializer import serialize_cls import nni
class BypassPlugin(TrainingTypePlugin): class BypassPlugin(TrainingTypePlugin):
...@@ -126,7 +126,7 @@ def get_accelerator_connector( ...@@ -126,7 +126,7 @@ def get_accelerator_connector(
) )
@serialize_cls @nni.trace
class BypassAccelerator(Accelerator): class BypassAccelerator(Accelerator):
def __init__(self, precision_plugin=None, device="cpu", **trainer_kwargs): def __init__(self, precision_plugin=None, device="cpu", **trainer_kwargs):
if precision_plugin is None: if precision_plugin is None:
......
...@@ -14,10 +14,9 @@ import nni ...@@ -14,10 +14,9 @@ import nni
from ..lightning import LightningModule, _AccuracyWithLogits, Lightning from ..lightning import LightningModule, _AccuracyWithLogits, Lightning
from .trainer import Trainer from .trainer import Trainer
from ....serializer import serialize_cls
@serialize_cls @nni.trace
class _MultiModelSupervisedLearningModule(LightningModule): class _MultiModelSupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric], def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
n_models: int = 0, n_models: int = 0,
...@@ -126,7 +125,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule): ...@@ -126,7 +125,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer) super().__init__(criterion, metrics, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)
@serialize_cls @nni.trace
class _ClassificationModule(MultiModelSupervisedLearningModule): class _ClassificationModule(MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss, def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001, learning_rate: float = 0.001,
...@@ -174,7 +173,7 @@ class Classification(Lightning): ...@@ -174,7 +173,7 @@ class Classification(Lightning):
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders) train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@serialize_cls @nni.trace
class _RegressionModule(MultiModelSupervisedLearningModule): class _RegressionModule(MultiModelSupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss, def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001, learning_rate: float = 0.001,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import pytorch_lightning as pl import pytorch_lightning as pl
from ....serializer import serialize_cls import nni
from .accelerator import BypassAccelerator from .accelerator import BypassAccelerator
@serialize_cls @nni.trace
class Trainer(pl.Trainer): class Trainer(pl.Trainer):
""" """
Trainer for cross-graph optimization. Trainer for cross-graph optimization.
......
...@@ -10,17 +10,17 @@ import pytorch_lightning as pl ...@@ -10,17 +10,17 @@ import pytorch_lightning as pl
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import torchmetrics import torchmetrics
from torch.utils.data import DataLoader import torch.utils.data as torch_data
import nni import nni
from nni.common.serializer import is_traceable
try: try:
from .cgo import trainer as cgo_trainer from .cgo import trainer as cgo_trainer
cgo_import_failed = False cgo_import_failed = False
except ImportError: except ImportError:
cgo_import_failed = True cgo_import_failed = True
from ...graph import Evaluator from nni.retiarii.graph import Evaluator
from ...serializer import serialize_cls
__all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression'] __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classification', 'Regression']
...@@ -40,9 +40,10 @@ class LightningModule(pl.LightningModule): ...@@ -40,9 +40,10 @@ class LightningModule(pl.LightningModule):
self.model = model self.model = model
Trainer = serialize_cls(pl.Trainer) Trainer = nni.trace(pl.Trainer)
DataLoader = serialize_cls(DataLoader) DataLoader = nni.trace(torch_data.DataLoader)
@nni.trace
class Lightning(Evaluator): class Lightning(Evaluator):
""" """
Delegate the whole training to PyTorch Lightning. Delegate the whole training to PyTorch Lightning.
...@@ -74,9 +75,10 @@ class Lightning(Evaluator): ...@@ -74,9 +75,10 @@ class Lightning(Evaluator):
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None): val_dataloaders: Union[DataLoader, List[DataLoader], None] = None):
assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.' assert isinstance(lightning_module, LightningModule), f'Lightning module must be an instance of {__name__}.LightningModule.'
if cgo_import_failed: if cgo_import_failed:
assert isinstance(trainer, Trainer), f'Trainer must be imported from {__name__}' assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), f'Trainer must be imported from {__name__}'
else: else:
assert isinstance(trainer, Trainer) or isinstance(trainer, cgo_trainer.Trainer), \ # this is not isinstance(trainer, Trainer) because with a different trace call, it can be different
assert (isinstance(trainer, pl.Trainer) and is_traceable(trainer)) or isinstance(trainer, cgo_trainer.Trainer), \
f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer' f'Trainer must be imported from {__name__} or nni.retiarii.evaluator.pytorch.cgo.trainer'
assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.' assert _check_dataloader(train_dataloader), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.' assert _check_dataloader(val_dataloaders), f'Wrong dataloader type. Try import DataLoader from {__name__}.'
...@@ -135,7 +137,7 @@ def _check_dataloader(dataloader): ...@@ -135,7 +137,7 @@ def _check_dataloader(dataloader):
return True return True
if isinstance(dataloader, list): if isinstance(dataloader, list):
return all([_check_dataloader(d) for d in dataloader]) return all([_check_dataloader(d) for d in dataloader])
return isinstance(dataloader, DataLoader) return isinstance(dataloader, torch_data.DataLoader) and is_traceable(dataloader)
### The following are some commonly used Lightning modules ### ### The following are some commonly used Lightning modules ###
...@@ -219,7 +221,7 @@ class _AccuracyWithLogits(torchmetrics.Accuracy): ...@@ -219,7 +221,7 @@ class _AccuracyWithLogits(torchmetrics.Accuracy):
return super().update(nn.functional.softmax(pred), target) return super().update(nn.functional.softmax(pred), target)
@serialize_cls @nni.trace
class _ClassificationModule(_SupervisedLearningModule): class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss, def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001, learning_rate: float = 0.001,
...@@ -272,7 +274,7 @@ class Classification(Lightning): ...@@ -272,7 +274,7 @@ class Classification(Lightning):
train_dataloader=train_dataloader, val_dataloaders=val_dataloaders) train_dataloader=train_dataloader, val_dataloaders=val_dataloaders)
@serialize_cls @nni.trace
class _RegressionModule(_SupervisedLearningModule): class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss, def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001, learning_rate: float = 0.001,
......
...@@ -200,7 +200,7 @@ class CGOExecutionEngine(AbstractExecutionEngine): ...@@ -200,7 +200,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# replace the module with a new instance whose n_models is set # replace the module with a new instance whose n_models is set
# n_models must be set in __init__, otherwise it cannot be captured by serialize_cls # n_models must be set in __init__, otherwise it cannot be captured by serialize_cls
new_module_init_params = model.evaluator.module._init_parameters.copy() new_module_init_params = model.evaluator.module.trace_kwargs.copy()
# MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users # MultiModelSupervisedLearningModule hides n_models of _MultiModelSupervisedLearningModule from users
new_module_init_params['n_models'] = len(multi_model) new_module_init_params['n_models'] = len(multi_model)
......
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
import logging import logging
from typing import Any, Callable from typing import Any, Callable
import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType from nni.utils import MetricType
from .graph import MetricData from .graph import MetricData
from .integration_api import register_advisor from .integration_api import register_advisor
from .serializer import json_dumps, json_loads
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -121,7 +121,7 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -121,7 +121,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
'placement_constraint': placement_constraint 'placement_constraint': placement_constraint
} }
_logger.debug('New trial sent: %s', new_trial) _logger.debug('New trial sent: %s', new_trial)
send(CommandType.NewTrialJob, json_dumps(new_trial)) send(CommandType.NewTrialJob, nni.dump(new_trial))
if self.send_trial_callback is not None: if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count return self.parameters_count
...@@ -140,7 +140,7 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -140,7 +140,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
def handle_trial_end(self, data): def handle_trial_end(self, data):
_logger.debug('Trial end: %s', data) _logger.debug('Trial end: %s', data)
self.trial_end_callback(json_loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED') data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data): def handle_report_metric_data(self, data):
...@@ -156,7 +156,7 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -156,7 +156,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
@staticmethod @staticmethod
def _process_value(value) -> Any: # hopefully a float def _process_value(value) -> Any: # hopefully a float
value = json_loads(value) value = nni.load(value)
if isinstance(value, dict): if isinstance(value, dict):
if 'default' in value: if 'default' in value:
return value['default'] return value['default']
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import json
from typing import NewType, Any from typing import NewType, Any
import nni import nni
from .serializer import json_loads
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor # NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import # because it would induce cycled import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any) RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
...@@ -41,7 +38,6 @@ def receive_trial_parameters() -> dict: ...@@ -41,7 +38,6 @@ def receive_trial_parameters() -> dict:
Reload with our json loads because NNI didn't use Retiarii serializer to load the data. Reload with our json loads because NNI didn't use Retiarii serializer to load the data.
""" """
params = nni.get_next_parameter() params = nni.get_next_parameter()
params = json_loads(json.dumps(params))
return params return params
......
...@@ -8,8 +8,9 @@ from typing import Any, List, Union, Dict, Optional ...@@ -8,8 +8,9 @@ from typing import Any, List, Union, Dict, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...serializer import Translatable, basic_unit from nni.common.serializer import Translatable
from ...utils import NoContextError from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import NoContextError
from .utils import generate_new_label, get_fixed_value from .utils import generate_new_label, get_fixed_value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment