Unverified Commit cae4308f authored by Yuge Zhang, committed by GitHub

[Retiarii] Rename APIs and refine documentation (#3404)

parent d047d6f4
......@@ -25,14 +25,14 @@ Type hint for edge's endpoint. The int indicates nodes' order.
"""
class TrainingConfig(abc.ABC):
class Evaluator(abc.ABC):
"""
Training config of a model. A training config should define where the training code is, and the configuration of
Evaluator of a model. An evaluator should define where the training code is, and the configuration of
training code. The configuration includes basic runtime information the trainer needs to know (such as the number of GPUs)
or tunable parameters (such as learning rate), depending on the implementation of the training code.
Each evaluator should define how it is interpreted in ``_execute()``, which takes only one argument: the mutated model class.
For example, functional training config might directly import the function and call the function.
For example, functional evaluator might directly import the function and call the function.
"""
def __repr__(self):
......@@ -40,15 +40,15 @@ class TrainingConfig(abc.ABC):
return f'{self.__class__.__name__}({items})'
@abc.abstractstaticmethod
def _load(ir: Any) -> 'TrainingConfig':
def _load(ir: Any) -> 'Evaluator':
pass
@staticmethod
def _load_with_type(type_name: str, ir: Any) -> 'Optional[TrainingConfig]':
def _load_with_type(type_name: str, ir: Any) -> 'Optional[Evaluator]':
if type_name == '_debug_no_trainer':
return DebugTraining()
return DebugEvaluator()
config_cls = import_(type_name)
assert issubclass(config_cls, TrainingConfig)
assert issubclass(config_cls, Evaluator)
return config_cls._load(ir)
@abc.abstractmethod
......@@ -83,8 +83,8 @@ class Model:
The outermost graph which usually takes dataset as input and feeds output to loss function.
graphs
All graphs (subgraphs) in this model.
training_config
Training config
evaluator
Model evaluator
history
Mutation history.
`self` is directly mutated from `self.history[-1]`;
......@@ -104,7 +104,7 @@ class Model:
self._root_graph_name: str = '_model'
self.graphs: Dict[str, Graph] = {}
self.training_config: Optional[TrainingConfig] = None
self.evaluator: Optional[Evaluator] = None
self.history: List[Model] = []
......@@ -113,7 +113,7 @@ class Model:
def __repr__(self):
return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
f'training_config={self.training_config}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'
f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'
@property
def root_graph(self) -> 'Graph':
......@@ -131,7 +131,7 @@ class Model:
new_model = Model(_internal=True)
new_model._root_graph_name = self._root_graph_name
new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
new_model.training_config = copy.deepcopy(self.training_config) # TODO this may be a problem when training config is large
new_model.evaluator = copy.deepcopy(self.evaluator) # TODO this may be a problem when evaluator is large
new_model.history = self.history + [self]
return new_model
......@@ -139,16 +139,16 @@ class Model:
def _load(ir: Any) -> 'Model':
model = Model(_internal=True)
for graph_name, graph_data in ir.items():
if graph_name != '_training_config':
if graph_name != '_evaluator':
Graph._load(model, graph_name, graph_data)._register()
model.training_config = TrainingConfig._load_with_type(ir['_training_config']['__type__'], ir['_training_config'])
model.evaluator = Evaluator._load_with_type(ir['_evaluator']['__type__'], ir['_evaluator'])
return model
def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()}
ret['_training_config'] = {
'__type__': get_full_class_name(self.training_config.__class__),
**self.training_config._dump()
ret['_evaluator'] = {
'__type__': get_full_class_name(self.evaluator.__class__),
**self.evaluator._dump()
}
return ret
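The dump/load pair above implies a simple round trip. A minimal sketch, assuming ``model`` is a fully constructed Model with an evaluator attached:

    ir = model._dump()            # {'_model': {...}, '_evaluator': {'__type__': ..., ...}}
    restored = Model._load(ir)    # graphs re-registered, evaluator rebuilt via _load_with_type
    assert type(restored.evaluator) is type(model.evaluator)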
......@@ -681,10 +681,10 @@ class IllegalGraphError(ValueError):
json.dump(graph, dump_file, indent=4)
class DebugTraining(TrainingConfig):
class DebugEvaluator(Evaluator):
@staticmethod
def _load(ir: Any) -> 'DebugTraining':
return DebugTraining()
def _load(ir: Any) -> 'DebugEvaluator':
return DebugEvaluator()
def _dump(self) -> Any:
return {'__type__': '_debug_no_trainer'}
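For illustration, a concrete evaluator beyond this debug stub might look like the following sketch; ``MyFunctionalEvaluator`` and its field names are hypothetical, not part of this commit:

    class MyFunctionalEvaluator(Evaluator):
        def __init__(self, function_name: str):
            self.function_name = function_name    # full dotted path of the training function

        @staticmethod
        def _load(ir: Any) -> 'MyFunctionalEvaluator':
            return MyFunctionalEvaluator(ir['function_name'])

        def _dump(self) -> Any:
            return {'function_name': self.function_name}

        def _execute(self, model_cls) -> Any:
            # import the training function and call it with the mutated model class
            return import_(self.function_name)(model_cls)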
......
......@@ -11,7 +11,7 @@ from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine
from .execution.api import set_execution_engine
from .integration_api import register_advisor
from .utils import json_dumps, json_loads
from .serializer import json_dumps, json_loads
_logger = logging.getLogger(__name__)
......
......@@ -3,7 +3,7 @@ from typing import NewType, Any
import nni
from .utils import json_loads
from .serializer import json_loads
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would introduce a circular import
......
......@@ -5,7 +5,8 @@ import warnings
import torch
import torch.nn as nn
from ...utils import uid, add_record, del_record, Translatable
from ...serializer import Translatable, basic_unit
from ...utils import uid
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
......@@ -281,21 +282,18 @@ class ValueChoice(Translatable, nn.Module):
return f'ValueChoice({self.candidates}, label={repr(self.label)})'
@basic_unit
class Placeholder(nn.Module):
# TODO: docstring
def __init__(self, label, related_info):
add_record(id(self), related_info)
def __init__(self, label, **related_info):
self.label = label
self.related_info = related_info
super(Placeholder, self).__init__()
super().__init__()
def forward(self, x):
return x
def __del__(self):
del_record(id(self))
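With the new signature, mutation hints are passed as plain keyword arguments and collected into ``related_info``; for example (option names are illustrative):

    ph = nn.Placeholder(label='mutable_0',
                        kernel_size_options=[1, 3, 5],
                        n_layer_options=[1, 2, 3, 4])
    assert ph.related_info['kernel_size_options'] == [1, 3, 5]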
class ChosenInputs(nn.Module):
"""
......
import torch
import torch.nn as nn
from ...utils import add_record, blackbox_module, del_record, version_larger_equal
from ...serializer import basic_unit
from ...serializer import transparent_serialize
from ...utils import version_larger_equal
# NOTE: support pytorch version >= 1.5.0
......@@ -36,135 +38,119 @@ if version_larger_equal(torch.__version__, '1.7.0'):
Module = nn.Module
class Sequential(nn.Sequential):
def __init__(self, *args):
add_record(id(self), {})
super(Sequential, self).__init__(*args)
def __del__(self):
del_record(id(self))
class ModuleList(nn.ModuleList):
def __init__(self, *args):
add_record(id(self), {})
super(ModuleList, self).__init__(*args)
def __del__(self):
del_record(id(self))
Identity = blackbox_module(nn.Identity)
Linear = blackbox_module(nn.Linear)
Conv1d = blackbox_module(nn.Conv1d)
Conv2d = blackbox_module(nn.Conv2d)
Conv3d = blackbox_module(nn.Conv3d)
ConvTranspose1d = blackbox_module(nn.ConvTranspose1d)
ConvTranspose2d = blackbox_module(nn.ConvTranspose2d)
ConvTranspose3d = blackbox_module(nn.ConvTranspose3d)
Threshold = blackbox_module(nn.Threshold)
ReLU = blackbox_module(nn.ReLU)
Hardtanh = blackbox_module(nn.Hardtanh)
ReLU6 = blackbox_module(nn.ReLU6)
Sigmoid = blackbox_module(nn.Sigmoid)
Tanh = blackbox_module(nn.Tanh)
Softmax = blackbox_module(nn.Softmax)
Softmax2d = blackbox_module(nn.Softmax2d)
LogSoftmax = blackbox_module(nn.LogSoftmax)
ELU = blackbox_module(nn.ELU)
SELU = blackbox_module(nn.SELU)
CELU = blackbox_module(nn.CELU)
GLU = blackbox_module(nn.GLU)
GELU = blackbox_module(nn.GELU)
Hardshrink = blackbox_module(nn.Hardshrink)
LeakyReLU = blackbox_module(nn.LeakyReLU)
LogSigmoid = blackbox_module(nn.LogSigmoid)
Softplus = blackbox_module(nn.Softplus)
Softshrink = blackbox_module(nn.Softshrink)
MultiheadAttention = blackbox_module(nn.MultiheadAttention)
PReLU = blackbox_module(nn.PReLU)
Softsign = blackbox_module(nn.Softsign)
Softmin = blackbox_module(nn.Softmin)
Tanhshrink = blackbox_module(nn.Tanhshrink)
RReLU = blackbox_module(nn.RReLU)
AvgPool1d = blackbox_module(nn.AvgPool1d)
AvgPool2d = blackbox_module(nn.AvgPool2d)
AvgPool3d = blackbox_module(nn.AvgPool3d)
MaxPool1d = blackbox_module(nn.MaxPool1d)
MaxPool2d = blackbox_module(nn.MaxPool2d)
MaxPool3d = blackbox_module(nn.MaxPool3d)
MaxUnpool1d = blackbox_module(nn.MaxUnpool1d)
MaxUnpool2d = blackbox_module(nn.MaxUnpool2d)
MaxUnpool3d = blackbox_module(nn.MaxUnpool3d)
FractionalMaxPool2d = blackbox_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = blackbox_module(nn.FractionalMaxPool3d)
LPPool1d = blackbox_module(nn.LPPool1d)
LPPool2d = blackbox_module(nn.LPPool2d)
LocalResponseNorm = blackbox_module(nn.LocalResponseNorm)
BatchNorm1d = blackbox_module(nn.BatchNorm1d)
BatchNorm2d = blackbox_module(nn.BatchNorm2d)
BatchNorm3d = blackbox_module(nn.BatchNorm3d)
InstanceNorm1d = blackbox_module(nn.InstanceNorm1d)
InstanceNorm2d = blackbox_module(nn.InstanceNorm2d)
InstanceNorm3d = blackbox_module(nn.InstanceNorm3d)
LayerNorm = blackbox_module(nn.LayerNorm)
GroupNorm = blackbox_module(nn.GroupNorm)
SyncBatchNorm = blackbox_module(nn.SyncBatchNorm)
Dropout = blackbox_module(nn.Dropout)
Dropout2d = blackbox_module(nn.Dropout2d)
Dropout3d = blackbox_module(nn.Dropout3d)
AlphaDropout = blackbox_module(nn.AlphaDropout)
FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout)
ReflectionPad1d = blackbox_module(nn.ReflectionPad1d)
ReflectionPad2d = blackbox_module(nn.ReflectionPad2d)
ReplicationPad2d = blackbox_module(nn.ReplicationPad2d)
ReplicationPad1d = blackbox_module(nn.ReplicationPad1d)
ReplicationPad3d = blackbox_module(nn.ReplicationPad3d)
CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d)
Embedding = blackbox_module(nn.Embedding)
EmbeddingBag = blackbox_module(nn.EmbeddingBag)
RNNBase = blackbox_module(nn.RNNBase)
RNN = blackbox_module(nn.RNN)
LSTM = blackbox_module(nn.LSTM)
GRU = blackbox_module(nn.GRU)
RNNCellBase = blackbox_module(nn.RNNCellBase)
RNNCell = blackbox_module(nn.RNNCell)
LSTMCell = blackbox_module(nn.LSTMCell)
GRUCell = blackbox_module(nn.GRUCell)
PixelShuffle = blackbox_module(nn.PixelShuffle)
Upsample = blackbox_module(nn.Upsample)
UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d)
PairwiseDistance = blackbox_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = blackbox_module(nn.TripletMarginLoss)
ZeroPad2d = blackbox_module(nn.ZeroPad2d)
ConstantPad1d = blackbox_module(nn.ConstantPad1d)
ConstantPad2d = blackbox_module(nn.ConstantPad2d)
ConstantPad3d = blackbox_module(nn.ConstantPad3d)
Bilinear = blackbox_module(nn.Bilinear)
CosineSimilarity = blackbox_module(nn.CosineSimilarity)
Unfold = blackbox_module(nn.Unfold)
Fold = blackbox_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = blackbox_module(nn.TransformerEncoder)
TransformerDecoder = blackbox_module(nn.TransformerDecoder)
TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
Transformer = blackbox_module(nn.Transformer)
Flatten = blackbox_module(nn.Flatten)
Hardsigmoid = blackbox_module(nn.Hardsigmoid)
Sequential = transparent_serialize(nn.Sequential)
ModuleList = transparent_serialize(nn.ModuleList)
Identity = basic_unit(nn.Identity)
Linear = basic_unit(nn.Linear)
Conv1d = basic_unit(nn.Conv1d)
Conv2d = basic_unit(nn.Conv2d)
Conv3d = basic_unit(nn.Conv3d)
ConvTranspose1d = basic_unit(nn.ConvTranspose1d)
ConvTranspose2d = basic_unit(nn.ConvTranspose2d)
ConvTranspose3d = basic_unit(nn.ConvTranspose3d)
Threshold = basic_unit(nn.Threshold)
ReLU = basic_unit(nn.ReLU)
Hardtanh = basic_unit(nn.Hardtanh)
ReLU6 = basic_unit(nn.ReLU6)
Sigmoid = basic_unit(nn.Sigmoid)
Tanh = basic_unit(nn.Tanh)
Softmax = basic_unit(nn.Softmax)
Softmax2d = basic_unit(nn.Softmax2d)
LogSoftmax = basic_unit(nn.LogSoftmax)
ELU = basic_unit(nn.ELU)
SELU = basic_unit(nn.SELU)
CELU = basic_unit(nn.CELU)
GLU = basic_unit(nn.GLU)
GELU = basic_unit(nn.GELU)
Hardshrink = basic_unit(nn.Hardshrink)
LeakyReLU = basic_unit(nn.LeakyReLU)
LogSigmoid = basic_unit(nn.LogSigmoid)
Softplus = basic_unit(nn.Softplus)
Softshrink = basic_unit(nn.Softshrink)
MultiheadAttention = basic_unit(nn.MultiheadAttention)
PReLU = basic_unit(nn.PReLU)
Softsign = basic_unit(nn.Softsign)
Softmin = basic_unit(nn.Softmin)
Tanhshrink = basic_unit(nn.Tanhshrink)
RReLU = basic_unit(nn.RReLU)
AvgPool1d = basic_unit(nn.AvgPool1d)
AvgPool2d = basic_unit(nn.AvgPool2d)
AvgPool3d = basic_unit(nn.AvgPool3d)
MaxPool1d = basic_unit(nn.MaxPool1d)
MaxPool2d = basic_unit(nn.MaxPool2d)
MaxPool3d = basic_unit(nn.MaxPool3d)
MaxUnpool1d = basic_unit(nn.MaxUnpool1d)
MaxUnpool2d = basic_unit(nn.MaxUnpool2d)
MaxUnpool3d = basic_unit(nn.MaxUnpool3d)
FractionalMaxPool2d = basic_unit(nn.FractionalMaxPool2d)
FractionalMaxPool3d = basic_unit(nn.FractionalMaxPool3d)
LPPool1d = basic_unit(nn.LPPool1d)
LPPool2d = basic_unit(nn.LPPool2d)
LocalResponseNorm = basic_unit(nn.LocalResponseNorm)
BatchNorm1d = basic_unit(nn.BatchNorm1d)
BatchNorm2d = basic_unit(nn.BatchNorm2d)
BatchNorm3d = basic_unit(nn.BatchNorm3d)
InstanceNorm1d = basic_unit(nn.InstanceNorm1d)
InstanceNorm2d = basic_unit(nn.InstanceNorm2d)
InstanceNorm3d = basic_unit(nn.InstanceNorm3d)
LayerNorm = basic_unit(nn.LayerNorm)
GroupNorm = basic_unit(nn.GroupNorm)
SyncBatchNorm = basic_unit(nn.SyncBatchNorm)
Dropout = basic_unit(nn.Dropout)
Dropout2d = basic_unit(nn.Dropout2d)
Dropout3d = basic_unit(nn.Dropout3d)
AlphaDropout = basic_unit(nn.AlphaDropout)
FeatureAlphaDropout = basic_unit(nn.FeatureAlphaDropout)
ReflectionPad1d = basic_unit(nn.ReflectionPad1d)
ReflectionPad2d = basic_unit(nn.ReflectionPad2d)
ReplicationPad2d = basic_unit(nn.ReplicationPad2d)
ReplicationPad1d = basic_unit(nn.ReplicationPad1d)
ReplicationPad3d = basic_unit(nn.ReplicationPad3d)
CrossMapLRN2d = basic_unit(nn.CrossMapLRN2d)
Embedding = basic_unit(nn.Embedding)
EmbeddingBag = basic_unit(nn.EmbeddingBag)
RNNBase = basic_unit(nn.RNNBase)
RNN = basic_unit(nn.RNN)
LSTM = basic_unit(nn.LSTM)
GRU = basic_unit(nn.GRU)
RNNCellBase = basic_unit(nn.RNNCellBase)
RNNCell = basic_unit(nn.RNNCell)
LSTMCell = basic_unit(nn.LSTMCell)
GRUCell = basic_unit(nn.GRUCell)
PixelShuffle = basic_unit(nn.PixelShuffle)
Upsample = basic_unit(nn.Upsample)
UpsamplingNearest2d = basic_unit(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = basic_unit(nn.UpsamplingBilinear2d)
PairwiseDistance = basic_unit(nn.PairwiseDistance)
AdaptiveMaxPool1d = basic_unit(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = basic_unit(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = basic_unit(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = basic_unit(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = basic_unit(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = basic_unit(nn.AdaptiveAvgPool3d)
TripletMarginLoss = basic_unit(nn.TripletMarginLoss)
ZeroPad2d = basic_unit(nn.ZeroPad2d)
ConstantPad1d = basic_unit(nn.ConstantPad1d)
ConstantPad2d = basic_unit(nn.ConstantPad2d)
ConstantPad3d = basic_unit(nn.ConstantPad3d)
Bilinear = basic_unit(nn.Bilinear)
CosineSimilarity = basic_unit(nn.CosineSimilarity)
Unfold = basic_unit(nn.Unfold)
Fold = basic_unit(nn.Fold)
AdaptiveLogSoftmaxWithLoss = basic_unit(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = basic_unit(nn.TransformerEncoder)
TransformerDecoder = basic_unit(nn.TransformerDecoder)
TransformerEncoderLayer = basic_unit(nn.TransformerEncoderLayer)
TransformerDecoderLayer = basic_unit(nn.TransformerDecoderLayer)
Transformer = basic_unit(nn.Transformer)
Flatten = basic_unit(nn.Flatten)
Hardsigmoid = basic_unit(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = blackbox_module(nn.Hardswish)
Hardswish = basic_unit(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = blackbox_module(nn.SiLU)
Unflatten = blackbox_module(nn.Unflatten)
TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss)
SiLU = basic_unit(nn.SiLU)
Unflatten = basic_unit(nn.Unflatten)
TripletMarginWithDistanceLoss = basic_unit(nn.TripletMarginWithDistanceLoss)
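The wrapped classes behave exactly like their torch counterparts at runtime; the only difference is that init arguments are recorded for later serialization. A quick sketch (the ``_init_parameters`` attribute comes from the wrapper in serializer.py below):

    import nni.retiarii.nn.pytorch as nn

    conv = nn.Conv2d(3, 16, kernel_size=3)    # constructs a regular nn.Conv2d underneath
    assert conv._init_parameters == {'in_channels': 3, 'out_channels': 16, 'kernel_size': 3}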
from .functional import FunctionalTrainer
from .interface import BaseOneShotTrainer
......@@ -2,11 +2,6 @@ import abc
from typing import Any
class BaseTrainer(abc.ABC):
# Deprecated class
pass
class BaseOneShotTrainer(abc.ABC):
"""
Build many (possibly all) architectures into a full graph, search (while training) and export the best.
......
from .base import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from .darts import DartsTrainer
from .enas import EnasTrainer
from .proxyless import ProxylessTrainer
from .random import RandomTrainer, SinglePathTrainer
from .random import SinglePathTrainer, RandomTrainer
from .utils import replace_input_choice, replace_layer_choice
import abc
import functools
import inspect
from typing import Any
import json_tricks
from .utils import get_full_class_name, get_module_name, import_
def get_init_parameters_or_fail(obj, silently=False):
if hasattr(obj, '_init_parameters'):
return obj._init_parameters
elif silently:
return None
else:
raise ValueError(f'Object {obj} needs to be serializable but `_init_parameters` is not available. '
'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
'If it is a customized module, please decorate it with @basic_unit. '
'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
'try to use serialize or @serialize_cls.')
### This is a patch of json-tricks to make it more useful to us ###
def _serialize_class_instance_encode(obj, primitives=False):
assert not primitives, 'Encoding with primitives is not supported.'
try: # FIXME: raise error
if hasattr(obj, '__class__'):
return {
'__type__': get_full_class_name(obj.__class__),
'arguments': get_init_parameters_or_fail(obj)
}
except ValueError:
pass
return obj
def _serialize_class_instance_decode(obj):
if isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj:
return import_(obj['__type__'])(**obj['arguments'])
return obj
def _type_encode(obj, primitives=False):
assert not primitives, 'Encoding with primitives is not supported.'
if isinstance(obj, type):
return {'__typename__': get_full_class_name(obj, relocate_module=True)}
return obj
def _type_decode(obj):
if isinstance(obj, dict) and '__typename__' in obj:
return import_(obj['__typename__'])
return obj
json_loads = functools.partial(json_tricks.loads, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode])
json_dumps = functools.partial(json_tricks.dumps, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode])
json_load = functools.partial(json_tricks.load, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode])
json_dump = functools.partial(json_tricks.dump, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode])
### End of json-tricks patch ###
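Taken together, these hooks let a ``serialize_cls``-wrapped object round-trip through JSON as its full class name plus recorded init arguments. A minimal sketch, assuming the class is importable by its full name at load time:

    import torch.nn as nn

    @serialize_cls
    class MyOp(nn.Module):
        def __init__(self, hidden_units=128):
            super().__init__()
            self.hidden_units = hidden_units

    op = MyOp(hidden_units=256)
    payload = json_dumps(op)     # {"__type__": "<module>.MyOp", "arguments": {"hidden_units": 256}}
    op2 = json_loads(payload)    # re-imports MyOp and calls MyOp(hidden_units=256)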
class Translatable(abc.ABC):
"""
Inherit this class and implement ``_translate`` when the inner class needs a different
parameter from the wrapper class in its init function.
"""
@abc.abstractmethod
def _translate(self) -> Any:
pass
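For example, a wrapper-side argument can carry extra information yet translate down to the plain value that the inner ``__init__`` expects; a hypothetical sketch:

    class DefaultingChoice(Translatable):
        def __init__(self, candidates, default):
            self.candidates = candidates
            self.default = default

        def _translate(self) -> Any:
            # the wrapped class only ever sees the concrete default value
            return self.default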
def _create_wrapper_cls(cls, store_init_parameters=True):
    class wrapper(cls):
        def __init__(self, *args, **kwargs):
            if store_init_parameters:
                # map positional arguments onto parameter names from the original signature (skipping `self`)
                argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
                full_args = {}
                full_args.update(kwargs)
                assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
                for argname, value in zip(argname_list, args):
                    full_args[argname] = value
                # let Translatable arguments substitute their translated values before the real __init__ runs
                args = list(args)
                for i, value in enumerate(args):
                    if isinstance(value, Translatable):
                        args[i] = value._translate()
                for i, value in kwargs.items():
                    if isinstance(value, Translatable):
                        kwargs[i] = value._translate()
                self._init_parameters = full_args
            else:
                self._init_parameters = {}
            super().__init__(*args, **kwargs)
wrapper.__module__ = get_module_name(cls)
wrapper.__name__ = cls.__name__
wrapper.__qualname__ = cls.__qualname__
wrapper.__init__.__doc__ = cls.__init__.__doc__
return wrapper
def serialize_cls(cls):
"""
Create a serializable class. Use it as a decorator.
"""
return _create_wrapper_cls(cls)
def transparent_serialize(cls):
"""
Wrap a module without recording its init parameters. For internal use only.
"""
return _create_wrapper_cls(cls, store_init_parameters=False)
def serialize(cls, *args, **kwargs):
"""
Create a serializable instance inline, without using the decorator. For example,
.. code-block:: python
self.op = serialize(MyCustomOp, hidden_units=128)
"""
return serialize_cls(cls)(*args, **kwargs)
def basic_unit(cls):
"""
Wrap a module as a basic unit, so that the graph parser does not parse into it and it can be the target of mutation.
"""
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
return serialize_cls(cls)
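Typical usage, mirroring the examples later in this diff:

    import torch
    import torch.nn as nn

    @basic_unit
    class Swish(nn.Module):    # treated as an opaque node by the graph parser
        def forward(self, x):
            return x * torch.sigmoid(x)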
import abc
import functools
import inspect
from collections import defaultdict
from typing import Any
from pathlib import Path
import json_tricks
def import_(target: str, allow_none: bool = False) -> Any:
if target is None:
......@@ -23,145 +19,6 @@ def version_larger_equal(a: str, b: str) -> bool:
return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))
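Note that this comparison assumes purely numeric dotted versions; a local build label such as '1.7.0+cu101' (common in torch wheels) would make ``int()`` raise. A defensive variant, not part of this commit, could strip the label first:

    def version_larger_equal_safe(a: str, b: str) -> bool:
        # drop local version labels like '+cu101' before comparing numerically
        numeric = lambda v: tuple(map(int, v.split('+')[0].split('.')))
        return numeric(a) >= numeric(b)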
### This is a patch of json-tricks to make it more useful to us ###
def _blackbox_class_instance_encode(obj, primitives=False):
assert not primitives, 'Encoding with primitives is not supported.'
if hasattr(obj, '__class__') and hasattr(obj, '__init_parameters__'):
return {
'__type__': get_full_class_name(obj.__class__),
'arguments': obj.__init_parameters__
}
return obj
def _blackbox_class_instance_decode(obj):
if isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj:
return import_(obj['__type__'])(**obj['arguments'])
return obj
def _type_encode(obj, primitives=False):
assert not primitives, 'Encoding with primitives is not supported.'
if isinstance(obj, type):
return {'__typename__': get_full_class_name(obj, relocate_module=True)}
return obj
def _type_decode(obj):
if isinstance(obj, dict) and '__typename__' in obj:
return import_(obj['__typename__'])
return obj
json_loads = functools.partial(json_tricks.loads, extra_obj_pairs_hooks=[_blackbox_class_instance_decode, _type_decode])
json_dumps = functools.partial(json_tricks.dumps, extra_obj_encoders=[_blackbox_class_instance_encode, _type_encode])
json_load = functools.partial(json_tricks.load, extra_obj_pairs_hooks=[_blackbox_class_instance_decode, _type_decode])
json_dump = functools.partial(json_tricks.dump, extra_obj_encoders=[_blackbox_class_instance_encode, _type_encode])
### End of json-tricks patch ###
_records = {}
def get_records():
global _records
return _records
def clear_records():
global _records
_records = {}
def add_record(key, value):
"""
"""
global _records
if _records is not None:
assert key not in _records, f'{key} already in _records. Conflict: {_records[key]}'
_records[key] = value
def del_record(key):
global _records
if _records is not None:
_records.pop(key, None)
class Translatable(abc.ABC):
"""
Inherit this class and implement ``_translate`` when the inner class needs a different
parameter from the wrapper class in its init function.
"""
@abc.abstractmethod
def _translate(self) -> Any:
pass
def _blackbox_cls(cls):
class wrapper(cls):
def __init__(self, *args, **kwargs):
argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
full_args = {}
full_args.update(kwargs)
assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
for argname, value in zip(argname_list, args):
full_args[argname] = value
# translate parameters
args = list(args)
for i, value in enumerate(args):
if isinstance(value, Translatable):
args[i] = value._translate()
for i, value in kwargs.items():
if isinstance(value, Translatable):
kwargs[i] = value._translate()
add_record(id(self), full_args) # for compatibility. Will remove soon.
self.__init_parameters__ = full_args
super().__init__(*args, **kwargs)
def __del__(self):
del_record(id(self))
wrapper.__module__ = _get_module_name(cls)
wrapper.__name__ = cls.__name__
wrapper.__qualname__ = cls.__qualname__
wrapper.__init__.__doc__ = cls.__init__.__doc__
return wrapper
def blackbox(cls, *args, **kwargs):
"""
To create a blackbox instance inline without the decorator. For example,
.. code-block:: python
self.op = blackbox(MyCustomOp, hidden_units=128)
"""
return _blackbox_cls(cls)(*args, **kwargs)
def blackbox_module(cls):
"""
Register a module. Use it as a decorator.
"""
return _blackbox_cls(cls)
def register_trainer(cls):
"""
Register a trainer. Use it as a decorator.
"""
return _blackbox_cls(cls)
_last_uid = defaultdict(int)
......@@ -170,7 +27,7 @@ def uid(namespace: str = 'default') -> int:
return _last_uid[namespace]
def _get_module_name(cls):
def get_module_name(cls):
module_name = cls.__module__
if module_name == '__main__':
# infer the module name with inspect
......@@ -180,7 +37,7 @@ def _get_module_name(cls):
main_file_path = Path(inspect.getsourcefile(frm[0]))
if main_file_path.parents[0] != Path('.'):
raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
module_name = main_file_path.stem
break
......@@ -195,5 +52,5 @@ def _get_module_name(cls):
def get_full_class_name(cls, relocate_module=False):
module_name = _get_module_name(cls) if relocate_module else cls.__module__
module_name = get_module_name(cls) if relocate_module else cls.__module__
return module_name + '.' + cls.__name__
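For instance (module path as of current torch versions):

    import torch.nn as nn
    get_full_class_name(nn.Conv2d)    # 'torch.nn.modules.conv.Conv2d'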
......@@ -7,9 +7,9 @@ import torch.nn as torch_nn
import ops
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii import basic_unit
@blackbox_module
@basic_unit
class AuxiliaryHead(nn.Module):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """
......
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii import basic_unit
@blackbox_module
@basic_unit
class DropPath(nn.Module):
def __init__(self, p=0.):
"""
......@@ -24,7 +24,7 @@ class DropPath(nn.Module):
return x
@blackbox_module
@basic_unit
class PoolBN(nn.Module):
"""
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
......@@ -45,7 +45,7 @@ class PoolBN(nn.Module):
out = self.bn(out)
return out
@blackbox_module
@basic_unit
class StdConv(nn.Module):
"""
Standard conv: ReLU - Conv - BN
......@@ -61,7 +61,7 @@ class StdConv(nn.Module):
def forward(self, x):
return self.net(x)
@blackbox_module
@basic_unit
class FacConv(nn.Module):
"""
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
......@@ -78,7 +78,7 @@ class FacConv(nn.Module):
def forward(self, x):
return self.net(x)
@blackbox_module
@basic_unit
class DilConv(nn.Module):
"""
(Dilated) depthwise separable conv.
......@@ -98,7 +98,7 @@ class DilConv(nn.Module):
def forward(self, x):
return self.net(x)
@blackbox_module
@basic_unit
class SepConv(nn.Module):
"""
Depthwise separable conv.
......@@ -114,7 +114,7 @@ class SepConv(nn.Module):
def forward(self, x):
return self.net(x)
@blackbox_module
@basic_unit
class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise (stride=2).
......
......@@ -4,9 +4,9 @@ import sys
import torch
from pathlib import Path
import nni.retiarii.trainer.pytorch.lightning as pl
import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.strategy as strategy
from nni.retiarii import blackbox_module as bm
from nni.retiarii import serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from torchvision import transforms
from torchvision.datasets import CIFAR10
......@@ -27,8 +27,8 @@ if __name__ == '__main__':
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = bm(CIFAR10)(root='data/cifar10', train=True, download=True, transform=train_transform)
test_dataset = bm(CIFAR10)(root='data/cifar10', train=False, download=True, transform=valid_transform)
train_dataset = serialize(CIFAR10, root='data/cifar10', train=True, download=True, transform=train_transform)
test_dataset = serialize(CIFAR10, root='data/cifar10', train=False, download=True, transform=valid_transform)
trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.2)
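A hedged sketch of wiring this evaluator into an experiment, using the names imported above; ``base_model`` is assumed to be the model space defined earlier in this test, and exact arguments may differ across NNI versions:

    exp = RetiariiExperiment(base_model, trainer, [], strategy.Random())
    exp_config = RetiariiExeConfig('local')
    exp_config.trial_concurrency = 1
    exp.run(exp_config, 8081)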
......
......@@ -9,7 +9,7 @@ from torchvision import transforms
from torchvision.datasets import CIFAR10
from nni.retiarii.experiment.pytorch import RetiariiExperiment
from nni.retiarii.trainer.pytorch import DartsTrainer
from nni.retiarii.oneshot.pytorch import DartsTrainer
from darts_model import CNN
......
from nni.retiarii import blackbox_module
from nni.retiarii import basic_unit
import nni.retiarii.nn.pytorch as nn
import warnings
......@@ -148,7 +148,7 @@ class MNASNet(nn.Module):
# zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios):
for filter_size, exp_ratio, stride in zip(base_filter_sizes, exp_ratios, strides):
# TODO: restrict that "choose" can only be used within mutator
ph = nn.Placeholder(label=f'mutable_{count}', related_info={
ph = nn.Placeholder(label=f'mutable_{count}', **{
'kernel_size_options': [1, 3, 5],
'n_layer_options': [1, 2, 3, 4],
'op_type_options': ['__mutated__.base_mnasnet.RegularConv',
......