Unverified Commit 443ba8c1 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Serialization infrastructure V2 (#4337)

parent 896c516f
......@@ -6,11 +6,13 @@ from typing import Any, List, Optional, Tuple
import torch.nn as nn
from ...mutator import Mutator
from ...graph import Cell, Graph, Model, ModelStatus, Node
from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node
from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit
from nni.retiarii.utils import uid
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
from .component import Repeat, NasBench101Cell, NasBench101Mutator
from ...utils import uid
class LayerChoiceMutator(Mutator):
......@@ -221,17 +223,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
graph = Graph(model, uid(), '_model', _internal=True)._register()
model.python_class = pytorch_model.__class__
if len(inspect.signature(model.python_class.__init__).parameters) > 1:
if not hasattr(pytorch_model, '_init_parameters'):
raise ValueError('Please annotate the model with @serialize decorator in python execution mode '
if not getattr(pytorch_model, '_nni_model_wrapper', False):
raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
'if your model has init parameters.')
model.python_init_params = pytorch_model._init_parameters
model.python_init_params = pytorch_model.trace_kwargs
else:
model.python_init_params = {}
for name, module in pytorch_model.named_modules():
# tricky case: value choice that serves as parameters are stored in _init_parameters
if hasattr(module, '_init_parameters'):
for key, value in module._init_parameters.items():
# tricky case: value choice that serves as parameters are stored in traced arguments
if is_basic_unit(module):
for key, value in module.trace_kwargs.items():
if isinstance(value, ValueChoice):
node = graph.add_node(name + '.init.' + key, 'ValueChoice', {'candidates': value.candidates})
node.label = value.label
......
......@@ -5,7 +5,6 @@ import torch
import torch.nn as nn
from ...serializer import basic_unit
from ...serializer import transparent_serialize
from ...utils import version_larger_equal
# NOTE: support pytorch version >= 1.5.0
......@@ -42,7 +41,7 @@ if version_larger_equal(torch.__version__, '1.7.0'):
Module = nn.Module
Sequential = nn.Sequential
ModuleList = transparent_serialize(nn.ModuleList)
ModuleList = basic_unit(nn.ModuleList, basic_unit_tag=False)
Identity = basic_unit(nn.Identity)
Linear = basic_unit(nn.Linear)
......
from typing import Any, Optional, Tuple
from ...utils import uid, get_current_context
from nni.retiarii.utils import ModelNamespace, get_current_context
def generate_new_label(label: Optional[str]):
    """Return ``label`` unchanged, or generate a fresh auto-numbered label when it is ``None``.

    NOTE(review): this span contained two return statements for the ``None``
    branch — a leftover ``'_mutation_' + str(uid('mutation'))`` (whose ``uid``
    import was removed in this change) shadowing the new ``ModelNamespace``
    path. The dead legacy line is removed here.
    """
    if label is None:
        # Auto-number within the current model namespace.
        return ModelNamespace.next_label()
    return label
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
import functools
import inspect
import types
from typing import Any
import warnings
from typing import Any, TypeVar, Union
import json_tricks
from nni.common.serializer import Traceable, is_traceable, trace, _copy_class_wrapper_attributes
from .utils import ModelNamespace
from .utils import get_importable_name, get_module_name, import_, reset_uid
__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
'is_basic_unit', 'is_model_wrapped']
T = TypeVar('T')
def get_init_parameters_or_fail(obj: Any):
    """Return the traced constructor keyword arguments (``trace_kwargs``) of ``obj``.

    Raises
    ------
    ValueError
        If ``obj`` was not created through the tracing machinery
        (``is_traceable(obj)`` is false), with a hint on how to make it
        serializable.

    NOTE(review): this span contained a second, older definition of the same
    function (keyed on ``_init_parameters`` with a ``silently`` flag) plus
    stray closing-string fragments from a merge. Only the current
    ``is_traceable``-based definition is kept, which matches the effective
    behavior (the later ``def`` shadowed the earlier one). The error-message
    typo "please to decorate" is also fixed.
    """
    if is_traceable(obj):
        return obj.trace_kwargs
    raise ValueError(
        f'Object {obj} needs to be serializable but `trace_kwargs` is not available. '
        'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
        'If it is a customized module, please decorate it with @basic_unit. '
        'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
        'try to use @nni.trace.'
    )
def _serialize_class_instance_encode(obj, primitives=False):
    """json-tricks encoder: represent an object as its importable class name
    plus the traced init arguments; fall through to ``obj`` if not traceable."""
    assert not primitives, 'Encoding with primitives is not supported.'
    try:  # FIXME: raise error
        if hasattr(obj, '__class__'):
            payload = {
                '__type__': get_importable_name(obj.__class__),
                'arguments': get_init_parameters_or_fail(obj)
            }
            return payload
    except ValueError:
        # Not serializable through tracing — let json-tricks try other encoders.
        pass
    return obj
def _serialize_class_instance_decode(obj):
if isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj:
return import_(obj['__type__'])(**obj['arguments'])
return obj
def serialize(cls, *args, **kwargs):
"""
To create an serializable instance inline without decorator. For example,
.. code-block:: python
def _type_encode(obj, primitives=False):
assert not primitives, 'Encoding with primitives is not supported.'
if isinstance(obj, type):
return {'__typename__': get_importable_name(obj, relocate_module=True)}
if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
# This is not reliable for cases like closure, `open`, or objects that is callable but not intended to be serialized.
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return {'__typename__': get_importable_name(obj, relocate_module=True)}
return obj
self.op = serialize(MyCustomOp, hidden_units=128)
"""
warnings.warn('nni.retiarii.serialize is deprecated and will be removed in future release. ' +
'Try to use nni.trace, e.g., nni.trace(torch.optim.Adam)(learning_rate=1e-4) instead.',
category=DeprecationWarning)
return trace(cls)(*args, **kwargs)
def _type_decode(obj):
if isinstance(obj, dict) and '__typename__' in obj:
return import_(obj['__typename__'])
return obj
def serialize_cls(cls):
    """Deprecated: make a class serializable. Kept as a thin alias of ``trace``.

    NOTE(review): the warning text mentions ``nni.retiarii.serialize`` although
    this function is ``serialize_cls`` — possibly copy-paste; confirm upstream.
    """
    message = ('nni.retiarii.serialize is deprecated and will be removed in future release. '
               'Try to use nni.trace instead.')
    warnings.warn(message, category=DeprecationWarning)
    return trace(cls)
# Patched json-tricks (de)serializers: plug our custom encoders/decoders into
# json_tricks so that classes, functions, and traced instances round-trip
# through JSON. Encoders and pairs-hooks mirror each other pairwise.
json_loads = functools.partial(json_tricks.loads, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode])
json_dumps = functools.partial(json_tricks.dumps, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode])
json_load = functools.partial(json_tricks.load, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode])
json_dump = functools.partial(json_tricks.dump, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode])
def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
"""
To wrap a module as a basic unit, is to make it a primitive and stop the engine from digging deeper into it.
### End of json-tricks patch ###
``basic_unit_tag`` is true by default. If set to false, it will not be explicitly mark as a basic unit, and
graph parser will continue to parse. Currently, this is to handle a special case in ``nn.Sequential``.
.. code-block:: python
class Translatable(abc.ABC):
"""
Inherit this class and implement ``translate`` when the inner class needs a different
parameter from the wrapper class in its init function.
@basic_unit
class PrimitiveOp(nn.Module):
...
"""
_check_wrapped(cls)
@abc.abstractmethod
def _translate(self) -> Any:
pass
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
cls = trace(cls)
cls._nni_basic_unit = basic_unit_tag
def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False, stop_parsing=True):
class wrapper(cls):
def __init__(self, *args, **kwargs):
self._stop_parsing = stop_parsing
if reset_mutation_uid:
reset_uid('mutation')
if store_init_parameters:
argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
full_args = {}
full_args.update(kwargs)
assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
for argname, value in zip(argname_list, args):
full_args[argname] = value
# translate parameters
args = list(args)
for i, value in enumerate(args):
if isinstance(value, Translatable):
args[i] = value._translate()
for i, value in kwargs.items():
if isinstance(value, Translatable):
kwargs[i] = value._translate()
self._init_parameters = full_args
else:
self._init_parameters = {}
# HACK: for torch script
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import torch
cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
cls.trace_args = torch.jit.unused(cls.trace_args)
cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
super().__init__(*args, **kwargs)
return cls
wrapper.__module__ = get_module_name(cls)
wrapper.__name__ = cls.__name__
wrapper.__qualname__ = cls.__qualname__
wrapper.__init__.__doc__ = cls.__init__.__doc__
return wrapper
def model_wrapper(cls: T) -> Union[T, Traceable]:
"""
Wrap the model if you are using pure-python execution engine. For example
.. code-block:: python
def serialize_cls(cls):
"""
To create an serializable class.
"""
return _create_wrapper_cls(cls)
@model_wrapper
class MyModel(nn.Module):
...
The wrapper serves two purposes:
def transparent_serialize(cls):
"""
Wrap a module but does not record parameters. For internal use only.
1. Capture the init parameters of python class so that it can be re-instantiated in another process.
2. Reset uid in ``mutation`` namespace so that each model counts from zero.
Can be useful in unittest and other multi-model scenarios.
"""
return _create_wrapper_cls(cls, store_init_parameters=False)
_check_wrapped(cls)
import torch.nn as nn
assert issubclass(cls, nn.Module)
def serialize(cls, *args, **kwargs):
"""
To create an serializable instance inline without decorator. For example,
wrapper = trace(cls)
.. code-block:: python
class reset_wrapper(wrapper):
def __init__(self, *args, **kwargs):
with ModelNamespace():
super().__init__(*args, **kwargs)
self.op = serialize(MyCustomOp, hidden_units=128)
"""
return serialize_cls(cls)(*args, **kwargs)
_copy_class_wrapper_attributes(wrapper, reset_wrapper)
reset_wrapper.__wrapped__ = wrapper.__wrapped__
reset_wrapper._nni_model_wrapper = True
return reset_wrapper
def basic_unit(cls):
"""
To wrap a module as a basic unit, to stop it from parsing and make it mutate-able.
"""
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
return serialize_cls(cls)
def is_basic_unit(cls_or_instance) -> bool:
    """Return whether ``cls_or_instance`` (a class, or an instance of one) was
    wrapped with ``basic_unit`` (checked via the ``_nni_basic_unit`` marker)."""
    target = cls_or_instance if inspect.isclass(cls_or_instance) else cls_or_instance.__class__
    return getattr(target, '_nni_basic_unit', False)
def model_wrapper(cls):
"""
Wrap the model if you are using pure-python execution engine.
def is_model_wrapped(cls_or_instance) -> bool:
    """Return whether ``cls_or_instance`` (a class, or an instance of one) was
    wrapped with ``model_wrapper`` (checked via the ``_nni_model_wrapper`` marker)."""
    if inspect.isclass(cls_or_instance):
        target = cls_or_instance
    else:
        target = cls_or_instance.__class__
    return getattr(target, '_nni_model_wrapper', False)
The wrapper serves two purposes:
1. Capture the init parameters of python class so that it can be re-instantiated in another process.
2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios.
"""
return _create_wrapper_cls(cls, reset_mutation_uid=True, stop_parsing=False)
def _check_wrapped(cls: T) -> bool:
    """Raise ``TypeError`` if ``cls`` already carries a trace wrapper.

    NOTE(review): annotated ``-> bool`` but implicitly returns ``None`` on the
    happy path; kept as-is for interface compatibility — confirm intent.
    """
    already_wrapped = getattr(cls, '_traced', False) or getattr(cls, '_nni_model_wrapper', False)
    if already_wrapped:
        raise TypeError(f'{cls} is already wrapped with trace wrapper (basic_unit / model_wrapper / trace). Cannot wrap again.')
......@@ -25,6 +25,8 @@ def version_larger_equal(a: str, b: str) -> bool:
_last_uid = defaultdict(int)
_DEFAULT_MODEL_NAMESPACE = 'model'
def uid(namespace: str = 'default') -> int:
_last_uid[namespace] += 1
......@@ -77,6 +79,8 @@ class ContextStack:
Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
get the corresponding value in the namespace.
Note that this is not multi-processing safe. Also, the values will get cleared for a new process.
"""
_stack: Dict[str, List[Any]] = defaultdict(list)
......@@ -107,5 +111,46 @@ class ContextStack:
return cls._stack[key][-1]
class ModelNamespace:
    """
    To create an individual namespace for models to enable automatic numbering.
    """

    def __init__(self, key: str = _DEFAULT_MODEL_NAMESPACE):
        # for example, key: "model_wrapper"
        self.key = key

    def __enter__(self):
        # Push one more level onto the hierarchical context for `self.key`,
        # numbering this level with a fresh uid scoped to the parent context.
        # For example, currently the top of stack is [1, 2, 2], and [1, 2, 2, 3] is used,
        # the next thing up is [1, 2, 2, 4].
        # `reset_uid` to count from zero for "model_wrapper_1_2_2_4"
        try:
            current_context = ContextStack.top(self.key)
            next_uid = uid(self._simple_name(self.key, current_context))
            ContextStack.push(self.key, current_context + [next_uid])
            reset_uid(self._simple_name(self.key, current_context + [next_uid]))
        except NoContextError:
            # First entry for this key: start an empty (root) context.
            ContextStack.push(self.key, [])
            reset_uid(self._simple_name(self.key, []))

    def __exit__(self, *args, **kwargs):
        # Leave the current level regardless of exceptions.
        ContextStack.pop(self.key)

    @staticmethod
    def next_label(key: str = _DEFAULT_MODEL_NAMESPACE) -> str:
        # Produce the next auto-numbered label under the current context,
        # e.g. "model_1_2_3" when the top of stack for `key` is [1, 2].
        try:
            current_context = ContextStack.top(key)
        except NoContextError:
            # fallback to use "default" namespace
            return ModelNamespace._simple_name('default', [uid()])
        next_uid = uid(ModelNamespace._simple_name(key, current_context))
        return ModelNamespace._simple_name(key, current_context + [next_uid])

    @staticmethod
    def _simple_name(key: str, lst: List[Any]) -> str:
        # Join the context path into a flat name: key + "_1_2_3".
        return key + ''.join(['_' + str(k) for k in lst])
def get_current_context(key: str) -> Any:
    """Return the value currently on top of the ``ContextStack`` for ``key``."""
    return ContextStack.top(key)
......@@ -3,7 +3,6 @@
import logging
from collections import defaultdict
import json_tricks
from nni import NoMoreTrialError
from nni.assessor import AssessResult
......@@ -12,7 +11,8 @@ from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars
from .msg_dispatcher_base import MsgDispatcherBase
from .protocol import CommandType, send
from ..utils import MetricType, to_json
from ..common.serializer import dump, load
from ..utils import MetricType
_logger = logging.getLogger(__name__)
......@@ -63,7 +63,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p
ret['parameter_index'] = parameter_index
else:
ret['parameter_index'] = 0
return to_json(ret)
return dump(ret)
class MsgDispatcher(MsgDispatcherBase):
......@@ -115,8 +115,8 @@ class MsgDispatcher(MsgDispatcherBase):
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
for entry in data:
entry['value'] = entry['value'] if type(entry['value']) is str else json_tricks.dumps(entry['value'])
entry['value'] = json_tricks.loads(entry['value'])
entry['value'] = entry['value'] if type(entry['value']) is str else dump(entry['value'])
entry['value'] = load(entry['value'])
self.tuner.import_data(data)
def handle_add_customized_trial(self, data):
......@@ -133,7 +133,7 @@ class MsgDispatcher(MsgDispatcherBase):
"""
# metrics value is dumped as json string in trial, so we need to decode it here
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
data['value'] = load(data['value'])
if data['type'] == MetricType.FINAL:
self._handle_final_metric_data(data)
elif data['type'] == MetricType.PERIODICAL:
......@@ -167,7 +167,7 @@ class MsgDispatcher(MsgDispatcherBase):
if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None:
self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')
self.tuner.trial_end(load(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')
def _handle_final_metric_data(self, data):
"""Call tuner to process final results
......@@ -221,7 +221,7 @@ class MsgDispatcher(MsgDispatcherBase):
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
send(CommandType.KillTrialJob, dump(trial_job_id))
# notify tuner
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]',
dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
......@@ -239,5 +239,5 @@ class MsgDispatcher(MsgDispatcherBase):
if multi_thread_enabled():
self._handle_final_metric_data(data)
else:
data['value'] = to_json(data['value'])
data['value'] = dump(data['value'])
self.enqueue_command(CommandType.ReportMetricData, data)
......@@ -5,10 +5,10 @@ import threading
import logging
from multiprocessing.dummy import Pool as ThreadPool
from queue import Queue, Empty
import json_tricks
from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
from ..common import load
from ..recoverable import Recoverable
from .protocol import CommandType, receive
......@@ -50,7 +50,7 @@ class MsgDispatcherBase(Recoverable):
while not self.stopping:
command, data = receive()
if data:
data = json_tricks.loads(data)
data = load(data)
if command is None or command is CommandType.Terminate:
break
......@@ -162,7 +162,7 @@ class MsgDispatcherBase(Recoverable):
def handle_request_trial_jobs(self, data):
"""The message dispatcher is demanded to generate ``data`` trial jobs.
These trial jobs should be sent via ``send(CommandType.NewTrialJob, json_tricks.dumps(parameter))``,
These trial jobs should be sent via ``send(CommandType.NewTrialJob, nni.dump(parameter))``,
where ``parameter`` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
Semantically, message dispatcher should do this ``send`` exactly ``data`` times.
......
......@@ -3,11 +3,10 @@
import os
import sys
import json
import time
import subprocess
from nni.utils import to_json
from nni.common import dump, load
from ..env_vars import trial_env_vars
_sysdir = trial_env_vars.NNI_SYS_DIR
......@@ -27,7 +26,7 @@ _multiphase = trial_env_vars.MULTI_PHASE
_param_index = 0
def request_next_parameter():
metric = to_json({
metric = dump({
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'REQUEST_PARAMETER',
'sequence': 0,
......@@ -54,7 +53,7 @@ def get_next_parameter():
while not (os.path.isfile(params_filepath) and os.path.getsize(params_filepath) > 0):
time.sleep(3)
params_file = open(params_filepath, 'r')
params = json.load(params_file)
params = load(fp=params_file)
_param_index += 1
return params
......
......@@ -5,7 +5,8 @@ import logging
import warnings
import colorama
import json_tricks
from nni.common import load
__all__ = [
'get_next_parameter',
......@@ -44,7 +45,7 @@ def get_sequence_id():
return 0
def send_metric(string):
metric = json_tricks.loads(string)
metric = load(string)
if metric['type'] == 'FINAL':
_logger.info('Final result: %s', metric['value'])
elif metric['type'] == 'PERIODICAL':
......
......@@ -4,7 +4,7 @@
# pylint: skip-file
import copy
import json_tricks
from nni.common import load
_params = None
......@@ -14,15 +14,19 @@ _last_metric = None
def get_next_parameter():
    # Return whatever was registered via `init_params` (stub platform for tests).
    return _params
def get_experiment_id():
    """Return a fixed placeholder experiment id (stub platform for tests)."""
    fake_experiment_id = 'fakeidex'
    return fake_experiment_id
def get_trial_id():
    """Return a fixed placeholder trial id (stub platform for tests)."""
    fake_trial_id = 'fakeidtr'
    return fake_trial_id
def get_sequence_id():
    """Return a fixed placeholder sequence id (stub platform for tests)."""
    fake_sequence_id = 0
    return fake_sequence_id
def send_metric(string):
    # Record the serialized metric in a module-level slot so tests can read it back.
    global _last_metric
    _last_metric = string
......@@ -32,8 +36,9 @@ def init_params(params):
global _params
_params = copy.deepcopy(params)
def get_last_metric():
    """Deserialize and return the metric most recently recorded by ``send_metric``.

    The 'value' field is itself a serialized payload, so it is decoded a
    second time.

    NOTE(review): this span contained a duplicated older body using
    ``json_tricks.loads``, whose import was removed in favor of
    ``nni.common``'s ``load``; the dead duplicate is dropped here.
    """
    metrics = load(_last_metric)
    metrics['value'] = load(metrics['value'])
    return metrics
......@@ -3,7 +3,7 @@
import os
import sqlite3
import json_tricks
import nni
from .constants import NNI_HOME_DIR
from .common_utils import get_file_lock
......@@ -95,7 +95,7 @@ class Config:
'''refresh to get latest config'''
sql = 'select params from ExperimentProfile where id=? order by revision DESC'
args = (self.experiment_id,)
self.config = config_v0_to_v1(json_tricks.loads(self.conn.cursor().execute(sql, args).fetchone()[0]))
self.config = config_v0_to_v1(nni.load(self.conn.cursor().execute(sql, args).fetchone()[0]))
def get_config(self):
'''get a value according to key'''
......@@ -159,7 +159,7 @@ class Experiments:
'''save config to local file'''
try:
with open(self.experiment_file, 'w') as file:
json_tricks.dump(self.experiments, file, indent=4)
nni.dump(self.experiments, file, indent=4)
except IOError as error:
print('Error:', error)
return ''
......@@ -169,7 +169,7 @@ class Experiments:
if os.path.exists(self.experiment_file):
try:
with open(self.experiment_file, 'r') as file:
return json_tricks.load(file)
return nni.load(fp=file)
except ValueError:
return {}
return {}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .utils import to_json
from .common.serializer import dump
from .runtime.env_vars import trial_env_vars
from .runtime import platform
......@@ -124,12 +124,12 @@ def report_intermediate_result(metric):
global _intermediate_seq
assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_intermediate_result'
metric = to_json({
metric = dump({
'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL',
'sequence': _intermediate_seq,
'value': to_json(metric)
'value': dump(metric)
})
_intermediate_seq += 1
platform.send_metric(metric)
......@@ -146,11 +146,11 @@ def report_final_result(metric):
"""
assert _params or trial_env_vars.NNI_PLATFORM is None, \
'nni.get_next_parameter() needs to be called before report_final_result'
metric = to_json({
metric = dump({
'parameter_id': _params['parameter_id'] if _params else None,
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL',
'sequence': 0,
'value': to_json(metric)
'value': dump(metric)
})
platform.send_metric(metric)
......@@ -2,17 +2,13 @@
# Licensed under the MIT license.
import copy
import functools
from enum import Enum, unique
from pathlib import Path
import json_tricks
from schema import And
from . import parameter_expressions
to_json = functools.partial(json_tricks.dumps, allow_nan=True)
@unique
class OptimizeMode(Enum):
"""Optimize Mode class
......
import inspect
import logging
import torch
import torch.nn as nn
from nni.retiarii.utils import version_larger_equal
from nni.retiarii import basic_unit
_logger = logging.getLogger(__name__)
_trace_module_names = [
module_name for module_name in dir(nn)
if module_name not in ['Module', 'ModuleList', 'ModuleDict', 'Sequential'] and
inspect.isclass(getattr(nn, module_name)) and issubclass(getattr(nn, module_name), nn.Module)
]
def wrap_module(original_class):
    """Patch ``original_class.__init__`` so every instance records its init
    arguments in ``self._init_parameters`` (keyed by parameter name).

    The pristine ``__init__`` is stashed as ``bak_init_for_inject`` so that
    ``unwrap_module`` can restore it later. Returns the (mutated) class.
    """
    original_init = original_class.__init__
    parameter_names = list(inspect.signature(original_class).parameters.keys())
    # Keep a copy of the pristine __init__, so the patched one can delegate without recursing.
    original_class.bak_init_for_inject = original_init

    def __init__(self, *args, **kws):
        recorded = dict(kws)
        for position, value in enumerate(args):
            recorded[parameter_names[position]] = value
        self._init_parameters = recorded
        original_init(self, *args, **kws)  # delegate to the pristine __init__

    original_class.__init__ = __init__  # install the recording __init__
    return original_class
def unwrap_module(wrapped_class):
    """Undo ``wrap_module``: put the stashed original ``__init__`` back and
    drop the stash. No-op when the class was never wrapped. Always returns ``None``."""
    if not hasattr(wrapped_class, 'bak_init_for_inject'):
        return None
    wrapped_class.__init__ = wrapped_class.bak_init_for_inject
    del wrapped_class.bak_init_for_inject
    return None
def remove_inject_pytorch_nn():
    """Undo ``inject_pytorch_nn``: restore every traced ``torch.nn`` module
    class to its original (unwrapped) implementation.

    NOTE(review): this span also contained ~100 leftover per-class
    ``unwrap_module(nn.X)`` assignments (plus version-gated ones) merged in
    from the pre-refactor implementation; they are superseded by the generic
    loop below and are removed here.
    """
    for name in _trace_module_names:
        wrapped = getattr(nn, name)
        # `__wrapped__` is set by the trace wrapper; absent means never injected.
        if hasattr(wrapped, '__wrapped__'):
            setattr(nn, name, wrapped.__wrapped__)
def inject_pytorch_nn():
    """Wrap every traced ``torch.nn`` module class with ``basic_unit`` so that
    instantiations record their init arguments (reversible via
    ``remove_inject_pytorch_nn``).

    NOTE(review): leftover per-class ``wrap_module(nn.X)`` assignments (plus
    version-gated ones) from the pre-refactor implementation were merged into
    this span; they are superseded by the generic loop below and are removed
    here.
    """
    for name in _trace_module_names:
        setattr(nn, name, basic_unit(getattr(nn, name)))
import json
import os
import threading
import unittest
......@@ -161,7 +160,7 @@ def _new_trainer():
def _load_mnist(n_models: int = 1):
path = Path(__file__).parent / 'mnist_pytorch.json'
with open(path) as f:
mnist_model = Model._load(json.load(f))
mnist_model = Model._load(nni.load(fp=f))
mnist_model.evaluator = _new_trainer()
if n_models == 1:
......@@ -176,12 +175,12 @@ def _load_mnist(n_models: int = 1):
def _get_final_result():
result = json.loads(nni.runtime.platform.test._last_metric)['value']
result = nni.load(nni.runtime.platform.test._last_metric)['value']
if isinstance(result, list):
return [float(_) for _ in result]
else:
if isinstance(result, str) and '[' in result:
return json.loads(result)
return nni.load(result)
return [float(result)]
......@@ -311,7 +310,7 @@ class CGOEngineTest(unittest.TestCase):
if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
cmd, data = protocol.receive()
params = json.loads(data)
params = nni.load(data)
tt.init_params(params)
......
......@@ -50,7 +50,7 @@ class FCNet(nn.Module):
return output.view(-1)
@serialize_cls
@nni.trace
class DiabetesDataset(Dataset):
def __init__(self, train=True):
data = load_diabetes()
......
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
@model_wrapper
class Model(nn.Module):
    """Toy search space: a fixed conv stem followed by two ``LayerChoice``
    mutables, used by ``test_model_wrapper``.

    ``@model_wrapper`` records the init arguments (exposed as
    ``trace_kwargs``) and the wrapped symbol (``trace_symbol``).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 10, 3)
        # NOTE(review): the labels asserted in test_model_wrapper
        # ('model_1', 'model_2') suggest LayerChoice labels are assigned in
        # creation order -- do not reorder these statements.
        self.conv2 = nn.LayerChoice([
            nn.Conv2d(10, 10, 3),
            nn.MaxPool2d(3)
        ])
        self.conv3 = nn.LayerChoice([
            nn.Identity(),
            nn.Conv2d(10, 10, 1)
        ])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Global-average-pool to (N, 10), then project to a single output.
        x = self.avgpool(x).view(x.size(0), -1)
        x = self.fc(x)
        return x
@model_wrapper
class ModelInner(nn.Module):
    """Inner search space with two ``LayerChoice`` mutables; embedded in
    ``ModelNested`` to exercise hierarchical label generation."""

    def __init__(self):
        super().__init__()
        # NOTE(review): creation order determines the '..._1' / '..._2' label
        # suffixes asserted in test_model_wrapper_nested -- keep the order.
        self.net1 = nn.LayerChoice([
            nn.Linear(10, 10),
            nn.Linear(10, 10, bias=False)
        ])
        self.net2 = nn.LayerChoice([
            nn.Linear(10, 10),
            nn.Linear(10, 10, bias=False)
        ])

    def forward(self, x):
        x = self.net1(x)
        x = self.net2(x)
        return x
@model_wrapper
class ModelNested(nn.Module):
    """Outer search space mixing nested wrapped models (``ModelInner``) with a
    direct ``LayerChoice``; used to check nested label prefixes
    ('model_1_*', 'model_2', 'model_3_*')."""

    def __init__(self):
        super().__init__()
        # NOTE(review): fc1/fc2/fc3 creation order fixes the label prefixes
        # asserted in test_model_wrapper_nested -- do not reorder.
        self.fc1 = ModelInner()
        self.fc2 = nn.LayerChoice([
            nn.Linear(10, 10),
            nn.Linear(10, 10, bias=False)
        ])
        self.fc3 = ModelInner()

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))
def test_model_wrapper():
    """The wrapper records symbol/kwargs and keeps mutable labels stable."""
    model = Model(3)
    # Tracing metadata reflects the wrapped class and its init arguments.
    assert model.trace_symbol == Model.__wrapped__
    assert model.trace_kwargs == {'in_channels': 3}
    assert (model.conv2.label, model.conv3.label) == ('model_1', 'model_2')
    out = model(torch.randn(1, 3, 5, 5))
    assert out.size() == torch.Size([1, 1])
    # A second instantiation re-records kwargs but reuses the same labels.
    model = Model(4)
    assert model.trace_symbol == Model.__wrapped__
    assert model.conv2.label == 'model_1'  # not changed
def test_model_wrapper_nested():
    """Nested wrapped models produce hierarchical labels '<outer>_<inner>'."""
    model = ModelNested()
    expected = {
        'fc1.net1': 'model_1_1',
        'fc1.net2': 'model_1_2',
        'fc2': 'model_2',
        'fc3.net1': 'model_3_1',
        'fc3.net2': 'model_3_2',
    }
    for path, label in expected.items():
        submodule = model
        for attr in path.split('.'):
            submodule = getattr(submodule, attr)
        assert submodule.label == label
if __name__ == '__main__':
    # Run the whole module when executed as a script. Previously only the
    # nested test was invoked here, silently skipping test_model_wrapper
    # (pytest discovers both regardless).
    test_model_wrapper()
    test_model_wrapper_nested()
import json
import math
from pathlib import Path
import re
import sys
import torch
from nni.retiarii import json_dumps, json_loads, serialize
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
sys.path.insert(0, Path(__file__).parent.as_posix())
from imported.model import ImportTest
class Foo:
    """Plain value object used to exercise serialization round-trips.

    ``bb`` holds 1000 elements so the tests can distinguish a dump that
    stores constructor arguments from one that embeds the whole payload.
    """

    def __init__(self, a, b=1):
        self.aa = a
        self.bb = [b + 1] * 1000

    def __eq__(self, other):
        return (self.aa, self.bb) == (other.aa, other.bb)
def test_serialize():
    """Round-trip ``serialize``-wrapped Foo through json_dumps/json_loads."""
    # Positional argument.
    module = serialize(Foo, 3)
    assert json_loads(json_dumps(module)) == module
    # Keyword arguments, given out of order.
    module = serialize(Foo, b=2, a=1)
    assert json_loads(json_dumps(module)) == module
    # Foo(1) is NOT serialize-wrapped, so the dump presumably has to embed the
    # whole object rather than just its init arguments -- expect a long dump.
    module = serialize(Foo, Foo(1), 5)
    dumped_module = json_dumps(module)
    assert len(dumped_module) > 200  # long: inner argument is not traced
    # When the inner Foo is also wrapped, only init arguments are recorded.
    module = serialize(Foo, serialize(Foo, 1), 5)
    dumped_module = json_dumps(module)
    assert len(dumped_module) < 200  # short: nested dump stores init args only
    assert json_loads(dumped_module) == module
def test_basic_unit():
    """A basic-unit module equals itself after a dump/load round-trip."""
    instance = ImportTest(3, 0.5)
    restored = json_loads(json_dumps(instance))
    assert restored == instance
def test_dataset():
    """Serialize torchvision datasets/dataloaders and verify the JSON payload."""
    # NOTE(review): downloads MNIST on first run -- requires network access.
    dataset = serialize(MNIST, root='data/mnist', train=False, download=True)
    dataloader = serialize(DataLoader, dataset, batch_size=10)
    # Expected payload. Key insertion order matters here, assuming json_dumps
    # preserves dict order, because the assertion compares exact strings.
    dumped_ans = {
        "__type__": "torch.utils.data.dataloader.DataLoader",
        "arguments": {
            "batch_size": 10,
            "dataset": {
                "__type__": "torchvision.datasets.mnist.MNIST",
                "arguments": {"root": "data/mnist", "train": False, "download": True}
            }
        }
    }
    assert json_dumps(dataloader) == json_dumps(dumped_ans)
    dataloader = json_loads(json_dumps(dumped_ans))
    assert isinstance(dataloader, DataLoader)
    # Transforms wrapped with serialize() should round-trip as well.
    dataset = serialize(MNIST, root='data/mnist', train=False, download=True,
                        transform=serialize(
                            transforms.Compose,
                            [serialize(transforms.ToTensor), serialize(transforms.Normalize, (0.1307,), (0.3081,))]
                        ))
    dataloader = serialize(DataLoader, dataset, batch_size=10)
    x, y = next(iter(json_loads(json_dumps(dataloader))))
    assert x.size() == torch.Size([10, 1, 28, 28])
    assert y.size() == torch.Size([10])
    # Plain (unwrapped) transforms inside a serialized dataset also work.
    dataset = serialize(MNIST, root='data/mnist', train=False, download=True,
                        transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
    dataloader = serialize(DataLoader, dataset, batch_size=10)
    x, y = next(iter(json_loads(json_dumps(dataloader))))
    assert x.size() == torch.Size([10, 1, 28, 28])
    assert y.size() == torch.Size([10])
def test_type():
    """Bare types and builtin functions serialize to a __typename__ record."""
    adam_repr = '{"__typename__": "torch.optim.adam.Adam"}'
    assert json_dumps(torch.optim.Adam) == adam_repr
    assert json_loads(adam_repr) == torch.optim.Adam
    # Locally-defined classes carry a module-qualified name.
    assert re.match(r'{"__typename__": "(.*)test_serializer.Foo"}', json_dumps(Foo))
    floor_repr = '{"__typename__": "math.floor"}'
    assert json_dumps(math.floor) == floor_repr
    assert json_loads(floor_repr) == math.floor
if __name__ == '__main__':
    # Allow running this test module directly, without pytest.
    test_serialize()
    test_basic_unit()
    test_dataset()
    test_type()
import math
from pathlib import Path
import re
import sys
import nni
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from nni.common.serializer import is_traceable
if True: # prevent auto formatting
sys.path.insert(0, Path(__file__).parent.as_posix())
from imported.model import ImportTest
@nni.trace
......@@ -23,8 +37,8 @@ def test_simple_class():
assert '"__kwargs__": {"a": 1, "b": 2}' in dump_str
assert '"__symbol__"' in dump_str
instance = nni.load(dump_str)
assert instance.get()._a == 1
assert instance.get()._b == 2
assert instance._a == 1
assert instance._b == 2
def test_external_class():
......@@ -44,7 +58,7 @@ def test_external_class():
r'"__kwargs__": {"in_channels": 3, "out_channels": 16, "kernel_size": 3}}'
conv = nni.load(nni.dump(conv))
assert conv.get().kernel_size == (3, 3)
assert conv.kernel_size == (3, 3)
def test_nested_class():
......@@ -53,8 +67,8 @@ def test_nested_class():
assert b._a._a == 1
dump_str = nni.dump(b)
b = nni.load(dump_str)
assert repr(b) == 'SerializableObject(type=SimpleClass, a=SerializableObject(type=SimpleClass, a=1, b=2))'
assert b.get()._a._a == 1
assert 'SimpleClass object at' in repr(b)
assert b._a._a == 1
def test_unserializable():
......@@ -64,8 +78,137 @@ def test_unserializable():
assert a._a == 1
def test_function():
    """Trace plain functions: a builtin and a local factory function."""
    # Tracing a builtin: the result compares like the underlying float but
    # carries trace metadata (symbol and positional args).
    t = nni.trace(math.sqrt, kw_only=False)(3)
    assert 1 < t < 2
    assert t.trace_symbol == math.sqrt
    assert t.trace_args == [3]
    t = nni.load(nni.dump(t))
    assert 1 < t < 2
    assert not is_traceable(t)  # trace not recovered, expected, limitation

    def simple_class_factory(bb=3.):
        # NOTE(review): SimpleClass is defined elsewhere in this module.
        return SimpleClass(1, bb)

    # Tracing a factory: the returned object is traceable and records kwargs.
    t = nni.trace(simple_class_factory)(4)
    ts = nni.dump(t)
    assert '__kwargs__' in ts
    t = nni.load(ts)
    assert t._a == 1
    assert is_traceable(t)
    # trace_copy() should preserve traceability and the recorded arguments.
    t = t.trace_copy()
    assert is_traceable(t)
    assert t.trace_symbol(10)._b == 10
    assert t.trace_kwargs['bb'] == 4
    assert is_traceable(t.trace_copy())
class Foo:
    """Equality-comparable fixture with a large (1000-element) list payload,
    used to tell "dump stores init args" apart from "dump embeds the object"."""

    def __init__(self, a, b=1):
        self.aa = a
        self.bb = [b + 1] * 1000

    def __eq__(self, other):
        return (self.aa, self.bb) == (other.aa, other.bb)
def test_custom_class():
    """nni.trace on a user-defined class must survive a dump/load round-trip."""
    obj = nni.trace(Foo)(3)
    assert nni.load(nni.dump(obj)) == obj
    obj = nni.trace(Foo)(b=2, a=1)
    assert nni.load(nni.dump(obj)) == obj
    # The untraced Foo(1) argument cannot be stored symbolically, so a long
    # payload is expected here.
    obj = nni.trace(Foo)(Foo(1), 5)
    payload = nni.dump(obj)
    assert len(payload) > 200
    # With the inner Foo traced as well, the round-trip is exact.
    obj = nni.trace(Foo)(nni.trace(Foo)(1), 5)
    payload = nni.dump(obj)
    assert nni.load(payload) == obj
# NOTE(review): this is a duplicate of the ``Foo`` class defined earlier in
# this file; the later definition shadows the earlier one. Likely a
# copy/merge artifact -- consider removing one of the two.
class Foo:
    """Equality-comparable fixture with a large list payload (see above)."""

    def __init__(self, a, b=1):
        self.aa = a
        self.bb = [b + 1 for _ in range(1000)]

    def __eq__(self, other):
        return self.aa == other.aa and self.bb == other.bb
def test_basic_unit_and_custom_import():
    """Basic-unit modules dump to an exact '__symbol__' + '__kwargs__' record."""
    original = ImportTest(3, 0.5)
    dumped = nni.dump(original)
    assert dumped == r'{"__symbol__": "path:imported.model.ImportTest", "__kwargs__": {"foo": 3, "bar": 0.5}}'
    assert nni.load(nni.dump(original)) == original
    # Wrapped torch modules dump with their original torch symbol path.
    import nni.retiarii.nn.pytorch as nn
    conv = nn.Conv2d(3, 10, 3, bias=False)
    dumped = nni.dump(conv)
    assert dumped == r'{"__symbol__": "path:torch.nn.modules.conv.Conv2d", "__kwargs__": {"in_channels": 3, "out_channels": 10, "kernel_size": 3, "bias": false}}'
    assert nni.load(dumped).bias is None
def test_dataset():
    """Traced torchvision datasets/dataloaders round-trip via nni.dump/load.

    Fix: removed two leftover debug ``print`` calls that dumped both payloads
    to stdout right before the equality assertion.
    """
    # NOTE(review): downloads MNIST on first run -- requires network access.
    dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True)
    dataloader = nni.trace(DataLoader)(dataset, batch_size=10)
    # Expected payload. Key insertion order matters, assuming nni.dump
    # preserves dict order, because the assertion compares exact strings.
    dumped_ans = {
        "__symbol__": "path:torch.utils.data.dataloader.DataLoader",
        "__kwargs__": {
            "dataset": {
                "__symbol__": "path:torchvision.datasets.mnist.MNIST",
                "__kwargs__": {"root": "data/mnist", "train": False, "download": True}
            },
            "batch_size": 10
        }
    }
    assert nni.dump(dataloader) == nni.dump(dumped_ans)
    dataloader = nni.load(nni.dump(dumped_ans))
    assert isinstance(dataloader, DataLoader)
    # Traced transforms nested inside a traced dataset.
    dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True,
                               transform=nni.trace(transforms.Compose)([
                                   nni.trace(transforms.ToTensor)(),
                                   nni.trace(transforms.Normalize)((0.1307,), (0.3081,))
                               ]))
    dataloader = nni.trace(DataLoader)(dataset, batch_size=10)
    x, y = next(iter(nni.load(nni.dump(dataloader))))
    assert x.size() == torch.Size([10, 1, 28, 28])
    assert y.size() == torch.Size([10])
    # Untraced transforms inside a traced dataset are also acceptable.
    dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True,
                               transform=nni.trace(transforms.Compose)(
                                   [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
                               ))
    dataloader = nni.trace(DataLoader)(dataset, batch_size=10)
    x, y = next(iter(nni.load(nni.dump(dataloader))))
    assert x.size() == torch.Size([10, 1, 28, 28])
    assert y.size() == torch.Size([10])
def test_type():
    """Bare types and builtin functions dump to a '__nni_type__' path record."""
    adam_repr = '{"__nni_type__": "path:torch.optim.adam.Adam"}'
    assert nni.dump(torch.optim.Adam) == adam_repr
    assert nni.load(adam_repr) == torch.optim.Adam
    # Locally-defined classes round-trip back to the same class object.
    assert Foo == nni.load(nni.dump(Foo))
    floor_repr = '{"__nni_type__": "path:math.floor"}'
    assert nni.dump(math.floor) == floor_repr
    assert nni.load(floor_repr) == math.floor
def test_lightning_earlystop():
    """A traced EarlyStopping callback must survive trainer dump/load."""
    import nni.retiarii.evaluator.pytorch.lightning as pl
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    trainer = pl.Trainer(callbacks=[nni.trace(EarlyStopping)(monitor="val_loss")])
    trainer = nni.load(nni.dump(trainer))
    found = False
    for cb in trainer.callbacks:
        if isinstance(cb, EarlyStopping):
            found = True
            break
    assert found
if __name__ == '__main__':
    # Run every test in this module when invoked directly. Previously several
    # tests defined in this file (test_function, test_custom_class,
    # test_basic_unit_and_custom_import, test_dataset, test_lightning_earlystop)
    # were never called here, and a block of stale commented-out duplicate
    # calls was left behind.
    test_simple_class()
    test_external_class()
    test_nested_class()
    test_unserializable()
    test_function()
    test_custom_class()
    test_basic_unit_and_custom_import()
    test_dataset()
    test_type()
    test_lightning_earlystop()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment