Unverified Commit 443ba8c1 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Serialization infrastructure V2 (#4337)

parent 896c516f
...@@ -6,11 +6,13 @@ from typing import Any, List, Optional, Tuple ...@@ -6,11 +6,13 @@ from typing import Any, List, Optional, Tuple
import torch.nn as nn import torch.nn as nn
from ...mutator import Mutator from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node
from ...graph import Cell, Graph, Model, ModelStatus, Node from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit
from nni.retiarii.utils import uid
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
from .component import Repeat, NasBench101Cell, NasBench101Mutator from .component import Repeat, NasBench101Cell, NasBench101Mutator
from ...utils import uid
class LayerChoiceMutator(Mutator): class LayerChoiceMutator(Mutator):
...@@ -221,17 +223,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op ...@@ -221,17 +223,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
graph = Graph(model, uid(), '_model', _internal=True)._register() graph = Graph(model, uid(), '_model', _internal=True)._register()
model.python_class = pytorch_model.__class__ model.python_class = pytorch_model.__class__
if len(inspect.signature(model.python_class.__init__).parameters) > 1: if len(inspect.signature(model.python_class.__init__).parameters) > 1:
if not hasattr(pytorch_model, '_init_parameters'): if not getattr(pytorch_model, '_nni_model_wrapper', False):
raise ValueError('Please annotate the model with @serialize decorator in python execution mode ' raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
'if your model has init parameters.') 'if your model has init parameters.')
model.python_init_params = pytorch_model._init_parameters model.python_init_params = pytorch_model.trace_kwargs
else: else:
model.python_init_params = {} model.python_init_params = {}
for name, module in pytorch_model.named_modules(): for name, module in pytorch_model.named_modules():
# tricky case: value choice that serves as parameters are stored in _init_parameters # tricky case: value choice that serves as parameters are stored in traced arguments
if hasattr(module, '_init_parameters'): if is_basic_unit(module):
for key, value in module._init_parameters.items(): for key, value in module.trace_kwargs.items():
if isinstance(value, ValueChoice): if isinstance(value, ValueChoice):
node = graph.add_node(name + '.init.' + key, 'ValueChoice', {'candidates': value.candidates}) node = graph.add_node(name + '.init.' + key, 'ValueChoice', {'candidates': value.candidates})
node.label = value.label node.label = value.label
......
...@@ -5,7 +5,6 @@ import torch ...@@ -5,7 +5,6 @@ import torch
import torch.nn as nn import torch.nn as nn
from ...serializer import basic_unit from ...serializer import basic_unit
from ...serializer import transparent_serialize
from ...utils import version_larger_equal from ...utils import version_larger_equal
# NOTE: support pytorch version >= 1.5.0 # NOTE: support pytorch version >= 1.5.0
...@@ -42,7 +41,7 @@ if version_larger_equal(torch.__version__, '1.7.0'): ...@@ -42,7 +41,7 @@ if version_larger_equal(torch.__version__, '1.7.0'):
Module = nn.Module Module = nn.Module
Sequential = nn.Sequential Sequential = nn.Sequential
ModuleList = transparent_serialize(nn.ModuleList) ModuleList = basic_unit(nn.ModuleList, basic_unit_tag=False)
Identity = basic_unit(nn.Identity) Identity = basic_unit(nn.Identity)
Linear = basic_unit(nn.Linear) Linear = basic_unit(nn.Linear)
......
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
from ...utils import uid, get_current_context from nni.retiarii.utils import ModelNamespace, get_current_context
def generate_new_label(label: Optional[str]): def generate_new_label(label: Optional[str]):
if label is None: if label is None:
return '_mutation_' + str(uid('mutation')) return ModelNamespace.next_label()
return label return label
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import abc
import functools
import inspect import inspect
import types import warnings
from typing import Any from typing import Any, TypeVar, Union
import json_tricks from nni.common.serializer import Traceable, is_traceable, trace, _copy_class_wrapper_attributes
from .utils import ModelNamespace
from .utils import get_importable_name, get_module_name, import_, reset_uid __all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
'is_basic_unit', 'is_model_wrapped']
T = TypeVar('T')
def get_init_parameters_or_fail(obj, silently=False):
if hasattr(obj, '_init_parameters'): def get_init_parameters_or_fail(obj: Any):
return obj._init_parameters if is_traceable(obj):
elif silently: return obj.trace_kwargs
return None raise ValueError(f'Object {obj} needs to be serializable but `trace_kwargs` is not available. '
else:
raise ValueError(f'Object {obj} needs to be serializable but `_init_parameters` is not available. '
'If it is a built-in module (like Conv2d), please import it from retiarii.nn. ' 'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
'If it is a customized module, please to decorate it with @basic_unit. ' 'If it is a customized module, please to decorate it with @basic_unit. '
'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), ' 'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
'try to use serialize or @serialize_cls.') 'try to use @nni.trace.')
### This is a patch of json-tricks to make it more useful to us ###
def _serialize_class_instance_encode(obj, primitives=False): def serialize(cls, *args, **kwargs):
assert not primitives, 'Encoding with primitives is not supported.' """
try: # FIXME: raise error To create an serializable instance inline without decorator. For example,
if hasattr(obj, '__class__'):
return {
'__type__': get_importable_name(obj.__class__),
'arguments': get_init_parameters_or_fail(obj)
}
except ValueError:
pass
return obj
def _serialize_class_instance_decode(obj):
if isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj:
return import_(obj['__type__'])(**obj['arguments'])
return obj
.. code-block:: python
def _type_encode(obj, primitives=False): self.op = serialize(MyCustomOp, hidden_units=128)
assert not primitives, 'Encoding with primitives is not supported.' """
if isinstance(obj, type): warnings.warn('nni.retiarii.serialize is deprecated and will be removed in future release. ' +
return {'__typename__': get_importable_name(obj, relocate_module=True)} 'Try to use nni.trace, e.g., nni.trace(torch.optim.Adam)(learning_rate=1e-4) instead.',
if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)): category=DeprecationWarning)
# This is not reliable for cases like closure, `open`, or objects that is callable but not intended to be serialized. return trace(cls)(*args, **kwargs)
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return {'__typename__': get_importable_name(obj, relocate_module=True)}
return obj
def _type_decode(obj): def serialize_cls(cls):
if isinstance(obj, dict) and '__typename__' in obj: """
return import_(obj['__typename__']) To create an serializable class.
return obj """
warnings.warn('nni.retiarii.serialize is deprecated and will be removed in future release. ' +
'Try to use nni.trace instead.', category=DeprecationWarning)
return trace(cls)
json_loads = functools.partial(json_tricks.loads, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode]) def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
json_dumps = functools.partial(json_tricks.dumps, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode]) """
json_load = functools.partial(json_tricks.load, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode]) To wrap a module as a basic unit, is to make it a primitive and stop the engine from digging deeper into it.
json_dump = functools.partial(json_tricks.dump, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode])
### End of json-tricks patch ### ``basic_unit_tag`` is true by default. If set to false, it will not be explicitly mark as a basic unit, and
graph parser will continue to parse. Currently, this is to handle a special case in ``nn.Sequential``.
.. code-block:: python
class Translatable(abc.ABC): @basic_unit
""" class PrimitiveOp(nn.Module):
Inherit this class and implement ``translate`` when the inner class needs a different ...
parameter from the wrapper class in its init function.
""" """
_check_wrapped(cls)
@abc.abstractmethod import torch.nn as nn
def _translate(self) -> Any: assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
pass
cls = trace(cls)
cls._nni_basic_unit = basic_unit_tag
def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False, stop_parsing=True): # HACK: for torch script
class wrapper(cls): # https://github.com/pytorch/pytorch/pull/45261
def __init__(self, *args, **kwargs): # https://github.com/pytorch/pytorch/issues/54688
self._stop_parsing = stop_parsing # I'm not sure whether there will be potential issues
if reset_mutation_uid: import torch
reset_uid('mutation') cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
if store_init_parameters: cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:] cls.trace_args = torch.jit.unused(cls.trace_args)
full_args = {} cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
full_args.update(kwargs)
assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
for argname, value in zip(argname_list, args):
full_args[argname] = value
# translate parameters
args = list(args)
for i, value in enumerate(args):
if isinstance(value, Translatable):
args[i] = value._translate()
for i, value in kwargs.items():
if isinstance(value, Translatable):
kwargs[i] = value._translate()
self._init_parameters = full_args
else:
self._init_parameters = {}
super().__init__(*args, **kwargs) return cls
wrapper.__module__ = get_module_name(cls)
wrapper.__name__ = cls.__name__
wrapper.__qualname__ = cls.__qualname__
wrapper.__init__.__doc__ = cls.__init__.__doc__
return wrapper def model_wrapper(cls: T) -> Union[T, Traceable]:
"""
Wrap the model if you are using pure-python execution engine. For example
.. code-block:: python
def serialize_cls(cls): @model_wrapper
""" class MyModel(nn.Module):
To create an serializable class. ...
"""
return _create_wrapper_cls(cls)
The wrapper serves two purposes:
def transparent_serialize(cls): 1. Capture the init parameters of python class so that it can be re-instantiated in another process.
""" 2. Reset uid in ``mutation`` namespace so that each model counts from zero.
Wrap a module but does not record parameters. For internal use only. Can be useful in unittest and other multi-model scenarios.
""" """
return _create_wrapper_cls(cls, store_init_parameters=False) _check_wrapped(cls)
import torch.nn as nn
assert issubclass(cls, nn.Module)
def serialize(cls, *args, **kwargs): wrapper = trace(cls)
"""
To create an serializable instance inline without decorator. For example,
.. code-block:: python class reset_wrapper(wrapper):
def __init__(self, *args, **kwargs):
with ModelNamespace():
super().__init__(*args, **kwargs)
self.op = serialize(MyCustomOp, hidden_units=128) _copy_class_wrapper_attributes(wrapper, reset_wrapper)
""" reset_wrapper.__wrapped__ = wrapper.__wrapped__
return serialize_cls(cls)(*args, **kwargs) reset_wrapper._nni_model_wrapper = True
return reset_wrapper
def basic_unit(cls): def is_basic_unit(cls_or_instance) -> bool:
""" if not inspect.isclass(cls_or_instance):
To wrap a module as a basic unit, to stop it from parsing and make it mutate-able. cls_or_instance = cls_or_instance.__class__
""" return getattr(cls_or_instance, '_nni_basic_unit', False)
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
return serialize_cls(cls)
def model_wrapper(cls): def is_model_wrapped(cls_or_instance) -> bool:
""" if not inspect.isclass(cls_or_instance):
Wrap the model if you are using pure-python execution engine. cls_or_instance = cls_or_instance.__class__
return getattr(cls_or_instance, '_nni_model_wrapper', False)
The wrapper serves two purposes:
1. Capture the init parameters of python class so that it can be re-instantiated in another process. def _check_wrapped(cls: T) -> bool:
2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios. if getattr(cls, '_traced', False) or getattr(cls, '_nni_model_wrapper', False):
""" raise TypeError(f'{cls} is already wrapped with trace wrapper (basic_unit / model_wrapper / trace). Cannot wrap again.')
return _create_wrapper_cls(cls, reset_mutation_uid=True, stop_parsing=False)
...@@ -25,6 +25,8 @@ def version_larger_equal(a: str, b: str) -> bool: ...@@ -25,6 +25,8 @@ def version_larger_equal(a: str, b: str) -> bool:
_last_uid = defaultdict(int) _last_uid = defaultdict(int)
_DEFAULT_MODEL_NAMESPACE = 'model'
def uid(namespace: str = 'default') -> int: def uid(namespace: str = 'default') -> int:
_last_uid[namespace] += 1 _last_uid[namespace] += 1
...@@ -77,6 +79,8 @@ class ContextStack: ...@@ -77,6 +79,8 @@ class ContextStack:
Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
get the corresponding value in the namespace. get the corresponding value in the namespace.
Note that this is not multi-processing safe. Also, the values will get cleared for a new process.
""" """
_stack: Dict[str, List[Any]] = defaultdict(list) _stack: Dict[str, List[Any]] = defaultdict(list)
...@@ -107,5 +111,46 @@ class ContextStack: ...@@ -107,5 +111,46 @@ class ContextStack:
return cls._stack[key][-1] return cls._stack[key][-1]
class ModelNamespace:
"""
To create an individual namespace for models to enable automatic numbering.
"""
def __init__(self, key: str = _DEFAULT_MODEL_NAMESPACE):
# for example, key: "model_wrapper"
self.key = key
def __enter__(self):
# For example, currently the top of stack is [1, 2, 2], and [1, 2, 2, 3] is used,
# the next thing up is [1, 2, 2, 4].
# `reset_uid` to count from zero for "model_wrapper_1_2_2_4"
try:
current_context = ContextStack.top(self.key)
next_uid = uid(self._simple_name(self.key, current_context))
ContextStack.push(self.key, current_context + [next_uid])
reset_uid(self._simple_name(self.key, current_context + [next_uid]))
except NoContextError:
ContextStack.push(self.key, [])
reset_uid(self._simple_name(self.key, []))
def __exit__(self, *args, **kwargs):
ContextStack.pop(self.key)
@staticmethod
def next_label(key: str = _DEFAULT_MODEL_NAMESPACE) -> str:
try:
current_context = ContextStack.top(key)
except NoContextError:
# fallback to use "default" namespace
return ModelNamespace._simple_name('default', [uid()])
next_uid = uid(ModelNamespace._simple_name(key, current_context))
return ModelNamespace._simple_name(key, current_context + [next_uid])
@staticmethod
def _simple_name(key: str, lst: List[Any]) -> str:
return key + ''.join(['_' + str(k) for k in lst])
def get_current_context(key: str) -> Any: def get_current_context(key: str) -> Any:
return ContextStack.top(key) return ContextStack.top(key)
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import logging import logging
from collections import defaultdict from collections import defaultdict
import json_tricks
from nni import NoMoreTrialError from nni import NoMoreTrialError
from nni.assessor import AssessResult from nni.assessor import AssessResult
...@@ -12,7 +11,8 @@ from .common import multi_thread_enabled, multi_phase_enabled ...@@ -12,7 +11,8 @@ from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars from .env_vars import dispatcher_env_vars
from .msg_dispatcher_base import MsgDispatcherBase from .msg_dispatcher_base import MsgDispatcherBase
from .protocol import CommandType, send from .protocol import CommandType, send
from ..utils import MetricType, to_json from ..common.serializer import dump, load
from ..utils import MetricType
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -63,7 +63,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p ...@@ -63,7 +63,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p
ret['parameter_index'] = parameter_index ret['parameter_index'] = parameter_index
else: else:
ret['parameter_index'] = 0 ret['parameter_index'] = 0
return to_json(ret) return dump(ret)
class MsgDispatcher(MsgDispatcherBase): class MsgDispatcher(MsgDispatcherBase):
...@@ -115,8 +115,8 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -115,8 +115,8 @@ class MsgDispatcher(MsgDispatcherBase):
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value' data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
""" """
for entry in data: for entry in data:
entry['value'] = entry['value'] if type(entry['value']) is str else json_tricks.dumps(entry['value']) entry['value'] = entry['value'] if type(entry['value']) is str else dump(entry['value'])
entry['value'] = json_tricks.loads(entry['value']) entry['value'] = load(entry['value'])
self.tuner.import_data(data) self.tuner.import_data(data)
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
...@@ -133,7 +133,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -133,7 +133,7 @@ class MsgDispatcher(MsgDispatcherBase):
""" """
# metrics value is dumped as json string in trial, so we need to decode it here # metrics value is dumped as json string in trial, so we need to decode it here
if 'value' in data: if 'value' in data:
data['value'] = json_tricks.loads(data['value']) data['value'] = load(data['value'])
if data['type'] == MetricType.FINAL: if data['type'] == MetricType.FINAL:
self._handle_final_metric_data(data) self._handle_final_metric_data(data)
elif data['type'] == MetricType.PERIODICAL: elif data['type'] == MetricType.PERIODICAL:
...@@ -167,7 +167,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -167,7 +167,7 @@ class MsgDispatcher(MsgDispatcherBase):
if self.assessor is not None: if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED') self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None: if self.tuner is not None:
self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED') self.tuner.trial_end(load(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')
def _handle_final_metric_data(self, data): def _handle_final_metric_data(self, data):
"""Call tuner to process final results """Call tuner to process final results
...@@ -221,7 +221,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -221,7 +221,7 @@ class MsgDispatcher(MsgDispatcherBase):
if result is AssessResult.Bad: if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id) _logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id)) send(CommandType.KillTrialJob, dump(trial_job_id))
# notify tuner # notify tuner
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]',
dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS) dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
...@@ -239,5 +239,5 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -239,5 +239,5 @@ class MsgDispatcher(MsgDispatcherBase):
if multi_thread_enabled(): if multi_thread_enabled():
self._handle_final_metric_data(data) self._handle_final_metric_data(data)
else: else:
data['value'] = to_json(data['value']) data['value'] = dump(data['value'])
self.enqueue_command(CommandType.ReportMetricData, data) self.enqueue_command(CommandType.ReportMetricData, data)
...@@ -5,10 +5,10 @@ import threading ...@@ -5,10 +5,10 @@ import threading
import logging import logging
from multiprocessing.dummy import Pool as ThreadPool from multiprocessing.dummy import Pool as ThreadPool
from queue import Queue, Empty from queue import Queue, Empty
import json_tricks
from .common import multi_thread_enabled from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars from .env_vars import dispatcher_env_vars
from ..common import load
from ..recoverable import Recoverable from ..recoverable import Recoverable
from .protocol import CommandType, receive from .protocol import CommandType, receive
...@@ -50,7 +50,7 @@ class MsgDispatcherBase(Recoverable): ...@@ -50,7 +50,7 @@ class MsgDispatcherBase(Recoverable):
while not self.stopping: while not self.stopping:
command, data = receive() command, data = receive()
if data: if data:
data = json_tricks.loads(data) data = load(data)
if command is None or command is CommandType.Terminate: if command is None or command is CommandType.Terminate:
break break
...@@ -162,7 +162,7 @@ class MsgDispatcherBase(Recoverable): ...@@ -162,7 +162,7 @@ class MsgDispatcherBase(Recoverable):
def handle_request_trial_jobs(self, data): def handle_request_trial_jobs(self, data):
"""The message dispatcher is demanded to generate ``data`` trial jobs. """The message dispatcher is demanded to generate ``data`` trial jobs.
These trial jobs should be sent via ``send(CommandType.NewTrialJob, json_tricks.dumps(parameter))``, These trial jobs should be sent via ``send(CommandType.NewTrialJob, nni.dump(parameter))``,
where ``parameter`` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter". where ``parameter`` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
Semantically, message dispatcher should do this ``send`` exactly ``data`` times. Semantically, message dispatcher should do this ``send`` exactly ``data`` times.
......
...@@ -3,11 +3,10 @@ ...@@ -3,11 +3,10 @@
import os import os
import sys import sys
import json
import time import time
import subprocess import subprocess
from nni.utils import to_json from nni.common import dump, load
from ..env_vars import trial_env_vars from ..env_vars import trial_env_vars
_sysdir = trial_env_vars.NNI_SYS_DIR _sysdir = trial_env_vars.NNI_SYS_DIR
...@@ -27,7 +26,7 @@ _multiphase = trial_env_vars.MULTI_PHASE ...@@ -27,7 +26,7 @@ _multiphase = trial_env_vars.MULTI_PHASE
_param_index = 0 _param_index = 0
def request_next_parameter(): def request_next_parameter():
metric = to_json({ metric = dump({
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'REQUEST_PARAMETER', 'type': 'REQUEST_PARAMETER',
'sequence': 0, 'sequence': 0,
...@@ -54,7 +53,7 @@ def get_next_parameter(): ...@@ -54,7 +53,7 @@ def get_next_parameter():
while not (os.path.isfile(params_filepath) and os.path.getsize(params_filepath) > 0): while not (os.path.isfile(params_filepath) and os.path.getsize(params_filepath) > 0):
time.sleep(3) time.sleep(3)
params_file = open(params_filepath, 'r') params_file = open(params_filepath, 'r')
params = json.load(params_file) params = load(fp=params_file)
_param_index += 1 _param_index += 1
return params return params
......
...@@ -5,7 +5,8 @@ import logging ...@@ -5,7 +5,8 @@ import logging
import warnings import warnings
import colorama import colorama
import json_tricks from nni.common import load
__all__ = [ __all__ = [
'get_next_parameter', 'get_next_parameter',
...@@ -44,7 +45,7 @@ def get_sequence_id(): ...@@ -44,7 +45,7 @@ def get_sequence_id():
return 0 return 0
def send_metric(string): def send_metric(string):
metric = json_tricks.loads(string) metric = load(string)
if metric['type'] == 'FINAL': if metric['type'] == 'FINAL':
_logger.info('Final result: %s', metric['value']) _logger.info('Final result: %s', metric['value'])
elif metric['type'] == 'PERIODICAL': elif metric['type'] == 'PERIODICAL':
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# pylint: skip-file # pylint: skip-file
import copy import copy
import json_tricks from nni.common import load
_params = None _params = None
...@@ -14,15 +14,19 @@ _last_metric = None ...@@ -14,15 +14,19 @@ _last_metric = None
def get_next_parameter(): def get_next_parameter():
return _params return _params
def get_experiment_id(): def get_experiment_id():
return 'fakeidex' return 'fakeidex'
def get_trial_id(): def get_trial_id():
return 'fakeidtr' return 'fakeidtr'
def get_sequence_id(): def get_sequence_id():
return 0 return 0
def send_metric(string): def send_metric(string):
global _last_metric global _last_metric
_last_metric = string _last_metric = string
...@@ -32,8 +36,9 @@ def init_params(params): ...@@ -32,8 +36,9 @@ def init_params(params):
global _params global _params
_params = copy.deepcopy(params) _params = copy.deepcopy(params)
def get_last_metric(): def get_last_metric():
metrics = json_tricks.loads(_last_metric) metrics = load(_last_metric)
metrics['value'] = json_tricks.loads(metrics['value']) metrics['value'] = load(metrics['value'])
return metrics return metrics
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import os import os
import sqlite3 import sqlite3
import json_tricks import nni
from .constants import NNI_HOME_DIR from .constants import NNI_HOME_DIR
from .common_utils import get_file_lock from .common_utils import get_file_lock
...@@ -95,7 +95,7 @@ class Config: ...@@ -95,7 +95,7 @@ class Config:
'''refresh to get latest config''' '''refresh to get latest config'''
sql = 'select params from ExperimentProfile where id=? order by revision DESC' sql = 'select params from ExperimentProfile where id=? order by revision DESC'
args = (self.experiment_id,) args = (self.experiment_id,)
self.config = config_v0_to_v1(json_tricks.loads(self.conn.cursor().execute(sql, args).fetchone()[0])) self.config = config_v0_to_v1(nni.load(self.conn.cursor().execute(sql, args).fetchone()[0]))
def get_config(self): def get_config(self):
'''get a value according to key''' '''get a value according to key'''
...@@ -159,7 +159,7 @@ class Experiments: ...@@ -159,7 +159,7 @@ class Experiments:
'''save config to local file''' '''save config to local file'''
try: try:
with open(self.experiment_file, 'w') as file: with open(self.experiment_file, 'w') as file:
json_tricks.dump(self.experiments, file, indent=4) nni.dump(self.experiments, file, indent=4)
except IOError as error: except IOError as error:
print('Error:', error) print('Error:', error)
return '' return ''
...@@ -169,7 +169,7 @@ class Experiments: ...@@ -169,7 +169,7 @@ class Experiments:
if os.path.exists(self.experiment_file): if os.path.exists(self.experiment_file):
try: try:
with open(self.experiment_file, 'r') as file: with open(self.experiment_file, 'r') as file:
return json_tricks.load(file) return nni.load(fp=file)
except ValueError: except ValueError:
return {} return {}
return {} return {}
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from .utils import to_json from .common.serializer import dump
from .runtime.env_vars import trial_env_vars from .runtime.env_vars import trial_env_vars
from .runtime import platform from .runtime import platform
...@@ -124,12 +124,12 @@ def report_intermediate_result(metric): ...@@ -124,12 +124,12 @@ def report_intermediate_result(metric):
global _intermediate_seq global _intermediate_seq
assert _params or trial_env_vars.NNI_PLATFORM is None, \ assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_intermediate_result' 'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric = to_json({ metric = dump({
'parameter_id': _params['parameter_id'] if _params else None, 'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL', 'type': 'PERIODICAL',
'sequence': _intermediate_seq, 'sequence': _intermediate_seq,
'value': to_json(metric) 'value': dump(metric)
}) })
_intermediate_seq += 1 _intermediate_seq += 1
platform.send_metric(metric) platform.send_metric(metric)
...@@ -146,11 +146,11 @@ def report_final_result(metric): ...@@ -146,11 +146,11 @@ def report_final_result(metric):
""" """
assert _params or trial_env_vars.NNI_PLATFORM is None, \ assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_final_result' 'nni.get_next_parameter() needs to be called before report_final_result'
metric = to_json({ metric = dump({
'parameter_id': _params['parameter_id'] if _params else None, 'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL', 'type': 'FINAL',
'sequence': 0, 'sequence': 0,
'value': to_json(metric) 'value': dump(metric)
}) })
platform.send_metric(metric) platform.send_metric(metric)
...@@ -2,17 +2,13 @@ ...@@ -2,17 +2,13 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import copy import copy
import functools
from enum import Enum, unique from enum import Enum, unique
from pathlib import Path from pathlib import Path
import json_tricks
from schema import And from schema import And
from . import parameter_expressions from . import parameter_expressions
to_json = functools.partial(json_tricks.dumps, allow_nan=True)
@unique @unique
class OptimizeMode(Enum): class OptimizeMode(Enum):
"""Optimize Mode class """Optimize Mode class
......
import inspect import inspect
import logging
import torch
import torch.nn as nn import torch.nn as nn
from nni.retiarii.utils import version_larger_equal from nni.retiarii import basic_unit
_logger = logging.getLogger(__name__) _trace_module_names = [
module_name for module_name in dir(nn)
if module_name not in ['Module', 'ModuleList', 'ModuleDict', 'Sequential'] and
inspect.isclass(getattr(nn, module_name)) and issubclass(getattr(nn, module_name), nn.Module)
]
def wrap_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
original_class.bak_init_for_inject = orig_init
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = arg
self._init_parameters = full_args
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
def unwrap_module(wrapped_class):
if hasattr(wrapped_class, 'bak_init_for_inject'):
wrapped_class.__init__ = wrapped_class.bak_init_for_inject
delattr(wrapped_class, 'bak_init_for_inject')
return None
def remove_inject_pytorch_nn(): def remove_inject_pytorch_nn():
Identity = unwrap_module(nn.Identity) for name in _trace_module_names:
Linear = unwrap_module(nn.Linear) if hasattr(getattr(nn, name), '__wrapped__'):
Conv1d = unwrap_module(nn.Conv1d) setattr(nn, name, getattr(nn, name).__wrapped__)
Conv2d = unwrap_module(nn.Conv2d)
Conv3d = unwrap_module(nn.Conv3d)
ConvTranspose1d = unwrap_module(nn.ConvTranspose1d)
ConvTranspose2d = unwrap_module(nn.ConvTranspose2d)
ConvTranspose3d = unwrap_module(nn.ConvTranspose3d)
Threshold = unwrap_module(nn.Threshold)
ReLU = unwrap_module(nn.ReLU)
Hardtanh = unwrap_module(nn.Hardtanh)
ReLU6 = unwrap_module(nn.ReLU6)
Sigmoid = unwrap_module(nn.Sigmoid)
Tanh = unwrap_module(nn.Tanh)
Softmax = unwrap_module(nn.Softmax)
Softmax2d = unwrap_module(nn.Softmax2d)
LogSoftmax = unwrap_module(nn.LogSoftmax)
ELU = unwrap_module(nn.ELU)
SELU = unwrap_module(nn.SELU)
CELU = unwrap_module(nn.CELU)
GLU = unwrap_module(nn.GLU)
GELU = unwrap_module(nn.GELU)
Hardshrink = unwrap_module(nn.Hardshrink)
LeakyReLU = unwrap_module(nn.LeakyReLU)
LogSigmoid = unwrap_module(nn.LogSigmoid)
Softplus = unwrap_module(nn.Softplus)
Softshrink = unwrap_module(nn.Softshrink)
MultiheadAttention = unwrap_module(nn.MultiheadAttention)
PReLU = unwrap_module(nn.PReLU)
Softsign = unwrap_module(nn.Softsign)
Softmin = unwrap_module(nn.Softmin)
Tanhshrink = unwrap_module(nn.Tanhshrink)
RReLU = unwrap_module(nn.RReLU)
AvgPool1d = unwrap_module(nn.AvgPool1d)
AvgPool2d = unwrap_module(nn.AvgPool2d)
AvgPool3d = unwrap_module(nn.AvgPool3d)
MaxPool1d = unwrap_module(nn.MaxPool1d)
MaxPool2d = unwrap_module(nn.MaxPool2d)
MaxPool3d = unwrap_module(nn.MaxPool3d)
MaxUnpool1d = unwrap_module(nn.MaxUnpool1d)
MaxUnpool2d = unwrap_module(nn.MaxUnpool2d)
MaxUnpool3d = unwrap_module(nn.MaxUnpool3d)
FractionalMaxPool2d = unwrap_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = unwrap_module(nn.FractionalMaxPool3d)
LPPool1d = unwrap_module(nn.LPPool1d)
LPPool2d = unwrap_module(nn.LPPool2d)
LocalResponseNorm = unwrap_module(nn.LocalResponseNorm)
BatchNorm1d = unwrap_module(nn.BatchNorm1d)
BatchNorm2d = unwrap_module(nn.BatchNorm2d)
BatchNorm3d = unwrap_module(nn.BatchNorm3d)
InstanceNorm1d = unwrap_module(nn.InstanceNorm1d)
InstanceNorm2d = unwrap_module(nn.InstanceNorm2d)
InstanceNorm3d = unwrap_module(nn.InstanceNorm3d)
LayerNorm = unwrap_module(nn.LayerNorm)
GroupNorm = unwrap_module(nn.GroupNorm)
SyncBatchNorm = unwrap_module(nn.SyncBatchNorm)
Dropout = unwrap_module(nn.Dropout)
Dropout2d = unwrap_module(nn.Dropout2d)
Dropout3d = unwrap_module(nn.Dropout3d)
AlphaDropout = unwrap_module(nn.AlphaDropout)
FeatureAlphaDropout = unwrap_module(nn.FeatureAlphaDropout)
ReflectionPad1d = unwrap_module(nn.ReflectionPad1d)
ReflectionPad2d = unwrap_module(nn.ReflectionPad2d)
ReplicationPad2d = unwrap_module(nn.ReplicationPad2d)
ReplicationPad1d = unwrap_module(nn.ReplicationPad1d)
ReplicationPad3d = unwrap_module(nn.ReplicationPad3d)
CrossMapLRN2d = unwrap_module(nn.CrossMapLRN2d)
Embedding = unwrap_module(nn.Embedding)
EmbeddingBag = unwrap_module(nn.EmbeddingBag)
RNNBase = unwrap_module(nn.RNNBase)
RNN = unwrap_module(nn.RNN)
LSTM = unwrap_module(nn.LSTM)
GRU = unwrap_module(nn.GRU)
RNNCellBase = unwrap_module(nn.RNNCellBase)
RNNCell = unwrap_module(nn.RNNCell)
LSTMCell = unwrap_module(nn.LSTMCell)
GRUCell = unwrap_module(nn.GRUCell)
PixelShuffle = unwrap_module(nn.PixelShuffle)
Upsample = unwrap_module(nn.Upsample)
UpsamplingNearest2d = unwrap_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = unwrap_module(nn.UpsamplingBilinear2d)
PairwiseDistance = unwrap_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = unwrap_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = unwrap_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = unwrap_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = unwrap_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = unwrap_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = unwrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = unwrap_module(nn.TripletMarginLoss)
ZeroPad2d = unwrap_module(nn.ZeroPad2d)
ConstantPad1d = unwrap_module(nn.ConstantPad1d)
ConstantPad2d = unwrap_module(nn.ConstantPad2d)
ConstantPad3d = unwrap_module(nn.ConstantPad3d)
Bilinear = unwrap_module(nn.Bilinear)
CosineSimilarity = unwrap_module(nn.CosineSimilarity)
Unfold = unwrap_module(nn.Unfold)
Fold = unwrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = unwrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = unwrap_module(nn.TransformerEncoder)
TransformerDecoder = unwrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = unwrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = unwrap_module(nn.TransformerDecoderLayer)
Transformer = unwrap_module(nn.Transformer)
Flatten = unwrap_module(nn.Flatten)
Hardsigmoid = unwrap_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = unwrap_module(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = unwrap_module(nn.SiLU)
Unflatten = unwrap_module(nn.Unflatten)
TripletMarginWithDistanceLoss = unwrap_module(nn.TripletMarginWithDistanceLoss)
def inject_pytorch_nn(): def inject_pytorch_nn():
Identity = wrap_module(nn.Identity) for name in _trace_module_names:
Linear = wrap_module(nn.Linear) setattr(nn, name, basic_unit(getattr(nn, name)))
Conv1d = wrap_module(nn.Conv1d)
Conv2d = wrap_module(nn.Conv2d)
Conv3d = wrap_module(nn.Conv3d)
ConvTranspose1d = wrap_module(nn.ConvTranspose1d)
ConvTranspose2d = wrap_module(nn.ConvTranspose2d)
ConvTranspose3d = wrap_module(nn.ConvTranspose3d)
Threshold = wrap_module(nn.Threshold)
ReLU = wrap_module(nn.ReLU)
Hardtanh = wrap_module(nn.Hardtanh)
ReLU6 = wrap_module(nn.ReLU6)
Sigmoid = wrap_module(nn.Sigmoid)
Tanh = wrap_module(nn.Tanh)
Softmax = wrap_module(nn.Softmax)
Softmax2d = wrap_module(nn.Softmax2d)
LogSoftmax = wrap_module(nn.LogSoftmax)
ELU = wrap_module(nn.ELU)
SELU = wrap_module(nn.SELU)
CELU = wrap_module(nn.CELU)
GLU = wrap_module(nn.GLU)
GELU = wrap_module(nn.GELU)
Hardshrink = wrap_module(nn.Hardshrink)
LeakyReLU = wrap_module(nn.LeakyReLU)
LogSigmoid = wrap_module(nn.LogSigmoid)
Softplus = wrap_module(nn.Softplus)
Softshrink = wrap_module(nn.Softshrink)
MultiheadAttention = wrap_module(nn.MultiheadAttention)
PReLU = wrap_module(nn.PReLU)
Softsign = wrap_module(nn.Softsign)
Softmin = wrap_module(nn.Softmin)
Tanhshrink = wrap_module(nn.Tanhshrink)
RReLU = wrap_module(nn.RReLU)
AvgPool1d = wrap_module(nn.AvgPool1d)
AvgPool2d = wrap_module(nn.AvgPool2d)
AvgPool3d = wrap_module(nn.AvgPool3d)
MaxPool1d = wrap_module(nn.MaxPool1d)
MaxPool2d = wrap_module(nn.MaxPool2d)
MaxPool3d = wrap_module(nn.MaxPool3d)
MaxUnpool1d = wrap_module(nn.MaxUnpool1d)
MaxUnpool2d = wrap_module(nn.MaxUnpool2d)
MaxUnpool3d = wrap_module(nn.MaxUnpool3d)
FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d)
LPPool1d = wrap_module(nn.LPPool1d)
LPPool2d = wrap_module(nn.LPPool2d)
LocalResponseNorm = wrap_module(nn.LocalResponseNorm)
BatchNorm1d = wrap_module(nn.BatchNorm1d)
BatchNorm2d = wrap_module(nn.BatchNorm2d)
BatchNorm3d = wrap_module(nn.BatchNorm3d)
InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
LayerNorm = wrap_module(nn.LayerNorm)
GroupNorm = wrap_module(nn.GroupNorm)
SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
Dropout = wrap_module(nn.Dropout)
Dropout2d = wrap_module(nn.Dropout2d)
Dropout3d = wrap_module(nn.Dropout3d)
AlphaDropout = wrap_module(nn.AlphaDropout)
FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
Embedding = wrap_module(nn.Embedding)
EmbeddingBag = wrap_module(nn.EmbeddingBag)
RNNBase = wrap_module(nn.RNNBase)
RNN = wrap_module(nn.RNN)
LSTM = wrap_module(nn.LSTM)
GRU = wrap_module(nn.GRU)
RNNCellBase = wrap_module(nn.RNNCellBase)
RNNCell = wrap_module(nn.RNNCell)
LSTMCell = wrap_module(nn.LSTMCell)
GRUCell = wrap_module(nn.GRUCell)
PixelShuffle = wrap_module(nn.PixelShuffle)
Upsample = wrap_module(nn.Upsample)
UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
PairwiseDistance = wrap_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
ZeroPad2d = wrap_module(nn.ZeroPad2d)
ConstantPad1d = wrap_module(nn.ConstantPad1d)
ConstantPad2d = wrap_module(nn.ConstantPad2d)
ConstantPad3d = wrap_module(nn.ConstantPad3d)
Bilinear = wrap_module(nn.Bilinear)
CosineSimilarity = wrap_module(nn.CosineSimilarity)
Unfold = wrap_module(nn.Unfold)
Fold = wrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = wrap_module(nn.TransformerEncoder)
TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer)
Flatten = wrap_module(nn.Flatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = wrap_module(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = wrap_module(nn.SiLU)
Unflatten = wrap_module(nn.Unflatten)
TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
import json
import os import os
import threading import threading
import unittest import unittest
...@@ -161,7 +160,7 @@ def _new_trainer(): ...@@ -161,7 +160,7 @@ def _new_trainer():
def _load_mnist(n_models: int = 1): def _load_mnist(n_models: int = 1):
path = Path(__file__).parent / 'mnist_pytorch.json' path = Path(__file__).parent / 'mnist_pytorch.json'
with open(path) as f: with open(path) as f:
mnist_model = Model._load(json.load(f)) mnist_model = Model._load(nni.load(fp=f))
mnist_model.evaluator = _new_trainer() mnist_model.evaluator = _new_trainer()
if n_models == 1: if n_models == 1:
...@@ -176,12 +175,12 @@ def _load_mnist(n_models: int = 1): ...@@ -176,12 +175,12 @@ def _load_mnist(n_models: int = 1):
def _get_final_result(): def _get_final_result():
result = json.loads(nni.runtime.platform.test._last_metric)['value'] result = nni.load(nni.runtime.platform.test._last_metric)['value']
if isinstance(result, list): if isinstance(result, list):
return [float(_) for _ in result] return [float(_) for _ in result]
else: else:
if isinstance(result, str) and '[' in result: if isinstance(result, str) and '[' in result:
return json.loads(result) return nni.load(result)
return [float(result)] return [float(result)]
...@@ -311,7 +310,7 @@ class CGOEngineTest(unittest.TestCase): ...@@ -311,7 +310,7 @@ class CGOEngineTest(unittest.TestCase):
if torch.cuda.is_available() and torch.cuda.device_count() >= 2: if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
cmd, data = protocol.receive() cmd, data = protocol.receive()
params = json.loads(data) params = nni.load(data)
tt.init_params(params) tt.init_params(params)
......
...@@ -50,7 +50,7 @@ class FCNet(nn.Module): ...@@ -50,7 +50,7 @@ class FCNet(nn.Module):
return output.view(-1) return output.view(-1)
@serialize_cls @nni.trace
class DiabetesDataset(Dataset): class DiabetesDataset(Dataset):
def __init__(self, train=True): def __init__(self, train=True):
data = load_diabetes() data = load_diabetes()
......
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
@model_wrapper
class Model(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, 10, 3)
self.conv2 = nn.LayerChoice([
nn.Conv2d(10, 10, 3),
nn.MaxPool2d(3)
])
self.conv3 = nn.LayerChoice([
nn.Identity(),
nn.Conv2d(10, 10, 1)
])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(10, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.avgpool(x).view(x.size(0), -1)
x = self.fc(x)
return x
@model_wrapper
class ModelInner(nn.Module):
def __init__(self):
super().__init__()
self.net1 = nn.LayerChoice([
nn.Linear(10, 10),
nn.Linear(10, 10, bias=False)
])
self.net2 = nn.LayerChoice([
nn.Linear(10, 10),
nn.Linear(10, 10, bias=False)
])
def forward(self, x):
x = self.net1(x)
x = self.net2(x)
return x
@model_wrapper
class ModelNested(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = ModelInner()
self.fc2 = nn.LayerChoice([
nn.Linear(10, 10),
nn.Linear(10, 10, bias=False)
])
self.fc3 = ModelInner()
def forward(self, x):
return self.fc3(self.fc2(self.fc1(x)))
def test_model_wrapper():
model = Model(3)
assert model.trace_symbol == Model.__wrapped__
assert model.trace_kwargs == {'in_channels': 3}
assert model.conv2.label == 'model_1'
assert model.conv3.label == 'model_2'
assert model(torch.randn(1, 3, 5, 5)).size() == torch.Size([1, 1])
model = Model(4)
assert model.trace_symbol == Model.__wrapped__
assert model.conv2.label == 'model_1' # not changed
def test_model_wrapper_nested():
model = ModelNested()
assert model.fc1.net1.label == 'model_1_1'
assert model.fc1.net2.label == 'model_1_2'
assert model.fc2.label == 'model_2'
assert model.fc3.net1.label == 'model_3_1'
assert model.fc3.net2.label == 'model_3_2'
if __name__ == '__main__':
test_model_wrapper_nested()
import json
import math
from pathlib import Path
import re
import sys
import torch
from nni.retiarii import json_dumps, json_loads, serialize
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
sys.path.insert(0, Path(__file__).parent.as_posix())
from imported.model import ImportTest
class Foo:
def __init__(self, a, b=1):
self.aa = a
self.bb = [b + 1 for _ in range(1000)]
def __eq__(self, other):
return self.aa == other.aa and self.bb == other.bb
def test_serialize():
module = serialize(Foo, 3)
assert json_loads(json_dumps(module)) == module
module = serialize(Foo, b=2, a=1)
assert json_loads(json_dumps(module)) == module
module = serialize(Foo, Foo(1), 5)
dumped_module = json_dumps(module)
assert len(dumped_module) > 200 # should not be too longer if the serialization is correct
module = serialize(Foo, serialize(Foo, 1), 5)
dumped_module = json_dumps(module)
assert len(dumped_module) < 200 # should not be too longer if the serialization is correct
assert json_loads(dumped_module) == module
def test_basic_unit():
module = ImportTest(3, 0.5)
assert json_loads(json_dumps(module)) == module
def test_dataset():
dataset = serialize(MNIST, root='data/mnist', train=False, download=True)
dataloader = serialize(DataLoader, dataset, batch_size=10)
dumped_ans = {
"__type__": "torch.utils.data.dataloader.DataLoader",
"arguments": {
"batch_size": 10,
"dataset": {
"__type__": "torchvision.datasets.mnist.MNIST",
"arguments": {"root": "data/mnist", "train": False, "download": True}
}
}
}
assert json_dumps(dataloader) == json_dumps(dumped_ans)
dataloader = json_loads(json_dumps(dumped_ans))
assert isinstance(dataloader, DataLoader)
dataset = serialize(MNIST, root='data/mnist', train=False, download=True,
transform=serialize(
transforms.Compose,
[serialize(transforms.ToTensor), serialize(transforms.Normalize, (0.1307,), (0.3081,))]
))
dataloader = serialize(DataLoader, dataset, batch_size=10)
x, y = next(iter(json_loads(json_dumps(dataloader))))
assert x.size() == torch.Size([10, 1, 28, 28])
assert y.size() == torch.Size([10])
dataset = serialize(MNIST, root='data/mnist', train=False, download=True,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
dataloader = serialize(DataLoader, dataset, batch_size=10)
x, y = next(iter(json_loads(json_dumps(dataloader))))
assert x.size() == torch.Size([10, 1, 28, 28])
assert y.size() == torch.Size([10])
def test_type():
assert json_dumps(torch.optim.Adam) == '{"__typename__": "torch.optim.adam.Adam"}'
assert json_loads('{"__typename__": "torch.optim.adam.Adam"}') == torch.optim.Adam
assert re.match(r'{"__typename__": "(.*)test_serializer.Foo"}', json_dumps(Foo))
assert json_dumps(math.floor) == '{"__typename__": "math.floor"}'
assert json_loads('{"__typename__": "math.floor"}') == math.floor
if __name__ == '__main__':
test_serialize()
test_basic_unit()
test_dataset()
test_type()
import math
from pathlib import Path
import re
import sys
import nni import nni
import torch import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from nni.common.serializer import is_traceable
if True: # prevent auto formatting
sys.path.insert(0, Path(__file__).parent.as_posix())
from imported.model import ImportTest
@nni.trace @nni.trace
...@@ -23,8 +37,8 @@ def test_simple_class(): ...@@ -23,8 +37,8 @@ def test_simple_class():
assert '"__kwargs__": {"a": 1, "b": 2}' in dump_str assert '"__kwargs__": {"a": 1, "b": 2}' in dump_str
assert '"__symbol__"' in dump_str assert '"__symbol__"' in dump_str
instance = nni.load(dump_str) instance = nni.load(dump_str)
assert instance.get()._a == 1 assert instance._a == 1
assert instance.get()._b == 2 assert instance._b == 2
def test_external_class(): def test_external_class():
...@@ -44,7 +58,7 @@ def test_external_class(): ...@@ -44,7 +58,7 @@ def test_external_class():
r'"__kwargs__": {"in_channels": 3, "out_channels": 16, "kernel_size": 3}}' r'"__kwargs__": {"in_channels": 3, "out_channels": 16, "kernel_size": 3}}'
conv = nni.load(nni.dump(conv)) conv = nni.load(nni.dump(conv))
assert conv.get().kernel_size == (3, 3) assert conv.kernel_size == (3, 3)
def test_nested_class(): def test_nested_class():
...@@ -53,8 +67,8 @@ def test_nested_class(): ...@@ -53,8 +67,8 @@ def test_nested_class():
assert b._a._a == 1 assert b._a._a == 1
dump_str = nni.dump(b) dump_str = nni.dump(b)
b = nni.load(dump_str) b = nni.load(dump_str)
assert repr(b) == 'SerializableObject(type=SimpleClass, a=SerializableObject(type=SimpleClass, a=1, b=2))' assert 'SimpleClass object at' in repr(b)
assert b.get()._a._a == 1 assert b._a._a == 1
def test_unserializable(): def test_unserializable():
...@@ -64,8 +78,137 @@ def test_unserializable(): ...@@ -64,8 +78,137 @@ def test_unserializable():
assert a._a == 1 assert a._a == 1
def test_function():
t = nni.trace(math.sqrt, kw_only=False)(3)
assert 1 < t < 2
assert t.trace_symbol == math.sqrt
assert t.trace_args == [3]
t = nni.load(nni.dump(t))
assert 1 < t < 2
assert not is_traceable(t) # trace not recovered, expected, limitation
def simple_class_factory(bb=3.):
return SimpleClass(1, bb)
t = nni.trace(simple_class_factory)(4)
ts = nni.dump(t)
assert '__kwargs__' in ts
t = nni.load(ts)
assert t._a == 1
assert is_traceable(t)
t = t.trace_copy()
assert is_traceable(t)
assert t.trace_symbol(10)._b == 10
assert t.trace_kwargs['bb'] == 4
assert is_traceable(t.trace_copy())
class Foo:
def __init__(self, a, b=1):
self.aa = a
self.bb = [b + 1 for _ in range(1000)]
def __eq__(self, other):
return self.aa == other.aa and self.bb == other.bb
def test_custom_class():
module = nni.trace(Foo)(3)
assert nni.load(nni.dump(module)) == module
module = nni.trace(Foo)(b=2, a=1)
assert nni.load(nni.dump(module)) == module
module = nni.trace(Foo)(Foo(1), 5)
dumped_module = nni.dump(module)
assert len(dumped_module) > 200 # should not be too longer if the serialization is correct
module = nni.trace(Foo)(nni.trace(Foo)(1), 5)
dumped_module = nni.dump(module)
assert nni.load(dumped_module) == module
class Foo:
def __init__(self, a, b=1):
self.aa = a
self.bb = [b + 1 for _ in range(1000)]
def __eq__(self, other):
return self.aa == other.aa and self.bb == other.bb
def test_basic_unit_and_custom_import():
module = ImportTest(3, 0.5)
ss = nni.dump(module)
assert ss == r'{"__symbol__": "path:imported.model.ImportTest", "__kwargs__": {"foo": 3, "bar": 0.5}}'
assert nni.load(nni.dump(module)) == module
import nni.retiarii.nn.pytorch as nn
module = nn.Conv2d(3, 10, 3, bias=False)
ss = nni.dump(module)
assert ss == r'{"__symbol__": "path:torch.nn.modules.conv.Conv2d", "__kwargs__": {"in_channels": 3, "out_channels": 10, "kernel_size": 3, "bias": false}}'
assert nni.load(ss).bias is None
def test_dataset():
dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True)
dataloader = nni.trace(DataLoader)(dataset, batch_size=10)
dumped_ans = {
"__symbol__": "path:torch.utils.data.dataloader.DataLoader",
"__kwargs__": {
"dataset": {
"__symbol__": "path:torchvision.datasets.mnist.MNIST",
"__kwargs__": {"root": "data/mnist", "train": False, "download": True}
},
"batch_size": 10
}
}
print(nni.dump(dataloader))
print(nni.dump(dumped_ans))
assert nni.dump(dataloader) == nni.dump(dumped_ans)
dataloader = nni.load(nni.dump(dumped_ans))
assert isinstance(dataloader, DataLoader)
dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True,
transform=nni.trace(transforms.Compose)([
nni.trace(transforms.ToTensor)(),
nni.trace(transforms.Normalize)((0.1307,), (0.3081,))
]))
dataloader = nni.trace(DataLoader)(dataset, batch_size=10)
x, y = next(iter(nni.load(nni.dump(dataloader))))
assert x.size() == torch.Size([10, 1, 28, 28])
assert y.size() == torch.Size([10])
dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True,
transform=nni.trace(transforms.Compose)(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
))
dataloader = nni.trace(DataLoader)(dataset, batch_size=10)
x, y = next(iter(nni.load(nni.dump(dataloader))))
assert x.size() == torch.Size([10, 1, 28, 28])
assert y.size() == torch.Size([10])
def test_type():
assert nni.dump(torch.optim.Adam) == '{"__nni_type__": "path:torch.optim.adam.Adam"}'
assert nni.load('{"__nni_type__": "path:torch.optim.adam.Adam"}') == torch.optim.Adam
assert Foo == nni.load(nni.dump(Foo))
assert nni.dump(math.floor) == '{"__nni_type__": "path:math.floor"}'
assert nni.load('{"__nni_type__": "path:math.floor"}') == math.floor
def test_lightning_earlystop():
import nni.retiarii.evaluator.pytorch.lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
trainer = pl.Trainer(callbacks=[nni.trace(EarlyStopping)(monitor="val_loss")])
trainer = nni.load(nni.dump(trainer))
assert any(isinstance(callback, EarlyStopping) for callback in trainer.callbacks)
if __name__ == '__main__': if __name__ == '__main__':
test_simple_class() # test_simple_class()
test_external_class() # test_external_class()
test_nested_class() # test_nested_class()
test_unserializable() # test_unserializable()
# test_basic_unit()
test_type()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment