Unverified Commit cae4308f authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Rename APIs and refine documentation (#3404)

parent d047d6f4
...@@ -25,14 +25,14 @@ Type hint for edge's endpoint. The int indicates nodes' order. ...@@ -25,14 +25,14 @@ Type hint for edge's endpoint. The int indicates nodes' order.
""" """
class TrainingConfig(abc.ABC): class Evaluator(abc.ABC):
""" """
Training config of a model. A training config should define where the training code is, and the configuration of Evaluator of a model. An evaluator should define where the training code is, and the configuration of
training code. The configuration includes basic runtime information trainer needs to know (such as number of GPUs) training code. The configuration includes basic runtime information trainer needs to know (such as number of GPUs)
or tune-able parameters (such as learning rate), depending on the implementation of training code. or tune-able parameters (such as learning rate), depending on the implementation of training code.
Each config should define how it is interpreted in ``_execute()``, taking only one argument which is the mutated model class. Each config should define how it is interpreted in ``_execute()``, taking only one argument which is the mutated model class.
For example, functional training config might directly import the function and call the function. For example, functional evaluator might directly import the function and call the function.
""" """
def __repr__(self): def __repr__(self):
...@@ -40,15 +40,15 @@ class TrainingConfig(abc.ABC): ...@@ -40,15 +40,15 @@ class TrainingConfig(abc.ABC):
return f'{self.__class__.__name__}({items})' return f'{self.__class__.__name__}({items})'
@abc.abstractstaticmethod @abc.abstractstaticmethod
def _load(ir: Any) -> 'TrainingConfig': def _load(ir: Any) -> 'Evaluator':
pass pass
@staticmethod @staticmethod
def _load_with_type(type_name: str, ir: Any) -> 'Optional[TrainingConfig]': def _load_with_type(type_name: str, ir: Any) -> 'Optional[Evaluator]':
if type_name == '_debug_no_trainer': if type_name == '_debug_no_trainer':
return DebugTraining() return DebugEvaluator()
config_cls = import_(type_name) config_cls = import_(type_name)
assert issubclass(config_cls, TrainingConfig) assert issubclass(config_cls, Evaluator)
return config_cls._load(ir) return config_cls._load(ir)
@abc.abstractmethod @abc.abstractmethod
...@@ -83,8 +83,8 @@ class Model: ...@@ -83,8 +83,8 @@ class Model:
The outermost graph which usually takes dataset as input and feeds output to loss function. The outermost graph which usually takes dataset as input and feeds output to loss function.
graphs graphs
All graphs (subgraphs) in this model. All graphs (subgraphs) in this model.
training_config evaluator
Training config Model evaluator
history history
Mutation history. Mutation history.
`self` is directly mutated from `self.history[-1]`; `self` is directly mutated from `self.history[-1]`;
...@@ -104,7 +104,7 @@ class Model: ...@@ -104,7 +104,7 @@ class Model:
self._root_graph_name: str = '_model' self._root_graph_name: str = '_model'
self.graphs: Dict[str, Graph] = {} self.graphs: Dict[str, Graph] = {}
self.training_config: Optional[TrainingConfig] = None self.evaluator: Optional[Evaluator] = None
self.history: List[Model] = [] self.history: List[Model] = []
...@@ -113,7 +113,7 @@ class Model: ...@@ -113,7 +113,7 @@ class Model:
def __repr__(self): def __repr__(self):
return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \ return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
f'training_config={self.training_config}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})' f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'
@property @property
def root_graph(self) -> 'Graph': def root_graph(self) -> 'Graph':
...@@ -131,7 +131,7 @@ class Model: ...@@ -131,7 +131,7 @@ class Model:
new_model = Model(_internal=True) new_model = Model(_internal=True)
new_model._root_graph_name = self._root_graph_name new_model._root_graph_name = self._root_graph_name
new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()} new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
new_model.training_config = copy.deepcopy(self.training_config) # TODO this may be a problem when training config is large new_model.evaluator = copy.deepcopy(self.evaluator) # TODO this may be a problem when evaluator is large
new_model.history = self.history + [self] new_model.history = self.history + [self]
return new_model return new_model
...@@ -139,16 +139,16 @@ class Model: ...@@ -139,16 +139,16 @@ class Model:
def _load(ir: Any) -> 'Model': def _load(ir: Any) -> 'Model':
model = Model(_internal=True) model = Model(_internal=True)
for graph_name, graph_data in ir.items(): for graph_name, graph_data in ir.items():
if graph_name != '_training_config': if graph_name != '_evaluator':
Graph._load(model, graph_name, graph_data)._register() Graph._load(model, graph_name, graph_data)._register()
model.training_config = TrainingConfig._load_with_type(ir['_training_config']['__type__'], ir['_training_config']) model.evaluator = Evaluator._load_with_type(ir['_evaluator']['__type__'], ir['_evaluator'])
return model return model
def _dump(self) -> Any: def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()} ret = {name: graph._dump() for name, graph in self.graphs.items()}
ret['_training_config'] = { ret['_evaluator'] = {
'__type__': get_full_class_name(self.training_config.__class__), '__type__': get_full_class_name(self.evaluator.__class__),
**self.training_config._dump() **self.evaluator._dump()
} }
return ret return ret
...@@ -681,10 +681,10 @@ class IllegalGraphError(ValueError): ...@@ -681,10 +681,10 @@ class IllegalGraphError(ValueError):
json.dump(graph, dump_file, indent=4) json.dump(graph, dump_file, indent=4)
class DebugTraining(TrainingConfig): class DebugEvaluator(Evaluator):
@staticmethod @staticmethod
def _load(ir: Any) -> 'DebugTraining': def _load(ir: Any) -> 'DebugEvaluator':
return DebugTraining() return DebugEvaluator()
def _dump(self) -> Any: def _dump(self) -> Any:
return {'__type__': '_debug_no_trainer'} return {'__type__': '_debug_no_trainer'}
......
...@@ -11,7 +11,7 @@ from .execution.base import BaseExecutionEngine ...@@ -11,7 +11,7 @@ from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine from .execution.cgo_engine import CGOExecutionEngine
from .execution.api import set_execution_engine from .execution.api import set_execution_engine
from .integration_api import register_advisor from .integration_api import register_advisor
from .utils import json_dumps, json_loads from .serializer import json_dumps, json_loads
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
......
...@@ -3,7 +3,7 @@ from typing import NewType, Any ...@@ -3,7 +3,7 @@ from typing import NewType, Any
import nni import nni
from .utils import json_loads from .serializer import json_loads
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor # NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import # because it would induce cycled import
......
...@@ -5,7 +5,8 @@ import warnings ...@@ -5,7 +5,8 @@ import warnings
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...utils import uid, add_record, del_record, Translatable from ...serializer import Translatable, basic_unit
from ...utils import uid
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs'] __all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
...@@ -281,21 +282,18 @@ class ValueChoice(Translatable, nn.Module): ...@@ -281,21 +282,18 @@ class ValueChoice(Translatable, nn.Module):
return f'ValueChoice({self.candidates}, label={repr(self.label)})' return f'ValueChoice({self.candidates}, label={repr(self.label)})'
@basic_unit
class Placeholder(nn.Module): class Placeholder(nn.Module):
# TODO: docstring # TODO: docstring
def __init__(self, label, related_info): def __init__(self, label, **related_info):
add_record(id(self), related_info)
self.label = label self.label = label
self.related_info = related_info self.related_info = related_info
super(Placeholder, self).__init__() super().__init__()
def forward(self, x): def forward(self, x):
return x return x
def __del__(self):
del_record(id(self))
class ChosenInputs(nn.Module): class ChosenInputs(nn.Module):
""" """
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...utils import add_record, blackbox_module, del_record, version_larger_equal from ...serializer import basic_unit
from ...serializer import transparent_serialize
from ...utils import version_larger_equal
# NOTE: support pytorch version >= 1.5.0 # NOTE: support pytorch version >= 1.5.0
...@@ -36,135 +38,119 @@ if version_larger_equal(torch.__version__, '1.7.0'): ...@@ -36,135 +38,119 @@ if version_larger_equal(torch.__version__, '1.7.0'):
Module = nn.Module Module = nn.Module
Sequential = transparent_serialize(nn.Sequential)
class Sequential(nn.Sequential): ModuleList = transparent_serialize(nn.ModuleList)
def __init__(self, *args):
add_record(id(self), {}) Identity = basic_unit(nn.Identity)
super(Sequential, self).__init__(*args) Linear = basic_unit(nn.Linear)
Conv1d = basic_unit(nn.Conv1d)
def __del__(self): Conv2d = basic_unit(nn.Conv2d)
del_record(id(self)) Conv3d = basic_unit(nn.Conv3d)
ConvTranspose1d = basic_unit(nn.ConvTranspose1d)
ConvTranspose2d = basic_unit(nn.ConvTranspose2d)
class ModuleList(nn.ModuleList): ConvTranspose3d = basic_unit(nn.ConvTranspose3d)
def __init__(self, *args): Threshold = basic_unit(nn.Threshold)
add_record(id(self), {}) ReLU = basic_unit(nn.ReLU)
super(ModuleList, self).__init__(*args) Hardtanh = basic_unit(nn.Hardtanh)
ReLU6 = basic_unit(nn.ReLU6)
def __del__(self): Sigmoid = basic_unit(nn.Sigmoid)
del_record(id(self)) Tanh = basic_unit(nn.Tanh)
Softmax = basic_unit(nn.Softmax)
Softmax2d = basic_unit(nn.Softmax2d)
Identity = blackbox_module(nn.Identity) LogSoftmax = basic_unit(nn.LogSoftmax)
Linear = blackbox_module(nn.Linear) ELU = basic_unit(nn.ELU)
Conv1d = blackbox_module(nn.Conv1d) SELU = basic_unit(nn.SELU)
Conv2d = blackbox_module(nn.Conv2d) CELU = basic_unit(nn.CELU)
Conv3d = blackbox_module(nn.Conv3d) GLU = basic_unit(nn.GLU)
ConvTranspose1d = blackbox_module(nn.ConvTranspose1d) GELU = basic_unit(nn.GELU)
ConvTranspose2d = blackbox_module(nn.ConvTranspose2d) Hardshrink = basic_unit(nn.Hardshrink)
ConvTranspose3d = blackbox_module(nn.ConvTranspose3d) LeakyReLU = basic_unit(nn.LeakyReLU)
Threshold = blackbox_module(nn.Threshold) LogSigmoid = basic_unit(nn.LogSigmoid)
ReLU = blackbox_module(nn.ReLU) Softplus = basic_unit(nn.Softplus)
Hardtanh = blackbox_module(nn.Hardtanh) Softshrink = basic_unit(nn.Softshrink)
ReLU6 = blackbox_module(nn.ReLU6) MultiheadAttention = basic_unit(nn.MultiheadAttention)
Sigmoid = blackbox_module(nn.Sigmoid) PReLU = basic_unit(nn.PReLU)
Tanh = blackbox_module(nn.Tanh) Softsign = basic_unit(nn.Softsign)
Softmax = blackbox_module(nn.Softmax) Softmin = basic_unit(nn.Softmin)
Softmax2d = blackbox_module(nn.Softmax2d) Tanhshrink = basic_unit(nn.Tanhshrink)
LogSoftmax = blackbox_module(nn.LogSoftmax) RReLU = basic_unit(nn.RReLU)
ELU = blackbox_module(nn.ELU) AvgPool1d = basic_unit(nn.AvgPool1d)
SELU = blackbox_module(nn.SELU) AvgPool2d = basic_unit(nn.AvgPool2d)
CELU = blackbox_module(nn.CELU) AvgPool3d = basic_unit(nn.AvgPool3d)
GLU = blackbox_module(nn.GLU) MaxPool1d = basic_unit(nn.MaxPool1d)
GELU = blackbox_module(nn.GELU) MaxPool2d = basic_unit(nn.MaxPool2d)
Hardshrink = blackbox_module(nn.Hardshrink) MaxPool3d = basic_unit(nn.MaxPool3d)
LeakyReLU = blackbox_module(nn.LeakyReLU) MaxUnpool1d = basic_unit(nn.MaxUnpool1d)
LogSigmoid = blackbox_module(nn.LogSigmoid) MaxUnpool2d = basic_unit(nn.MaxUnpool2d)
Softplus = blackbox_module(nn.Softplus) MaxUnpool3d = basic_unit(nn.MaxUnpool3d)
Softshrink = blackbox_module(nn.Softshrink) FractionalMaxPool2d = basic_unit(nn.FractionalMaxPool2d)
MultiheadAttention = blackbox_module(nn.MultiheadAttention) FractionalMaxPool3d = basic_unit(nn.FractionalMaxPool3d)
PReLU = blackbox_module(nn.PReLU) LPPool1d = basic_unit(nn.LPPool1d)
Softsign = blackbox_module(nn.Softsign) LPPool2d = basic_unit(nn.LPPool2d)
Softmin = blackbox_module(nn.Softmin) LocalResponseNorm = basic_unit(nn.LocalResponseNorm)
Tanhshrink = blackbox_module(nn.Tanhshrink) BatchNorm1d = basic_unit(nn.BatchNorm1d)
RReLU = blackbox_module(nn.RReLU) BatchNorm2d = basic_unit(nn.BatchNorm2d)
AvgPool1d = blackbox_module(nn.AvgPool1d) BatchNorm3d = basic_unit(nn.BatchNorm3d)
AvgPool2d = blackbox_module(nn.AvgPool2d) InstanceNorm1d = basic_unit(nn.InstanceNorm1d)
AvgPool3d = blackbox_module(nn.AvgPool3d) InstanceNorm2d = basic_unit(nn.InstanceNorm2d)
MaxPool1d = blackbox_module(nn.MaxPool1d) InstanceNorm3d = basic_unit(nn.InstanceNorm3d)
MaxPool2d = blackbox_module(nn.MaxPool2d) LayerNorm = basic_unit(nn.LayerNorm)
MaxPool3d = blackbox_module(nn.MaxPool3d) GroupNorm = basic_unit(nn.GroupNorm)
MaxUnpool1d = blackbox_module(nn.MaxUnpool1d) SyncBatchNorm = basic_unit(nn.SyncBatchNorm)
MaxUnpool2d = blackbox_module(nn.MaxUnpool2d) Dropout = basic_unit(nn.Dropout)
MaxUnpool3d = blackbox_module(nn.MaxUnpool3d) Dropout2d = basic_unit(nn.Dropout2d)
FractionalMaxPool2d = blackbox_module(nn.FractionalMaxPool2d) Dropout3d = basic_unit(nn.Dropout3d)
FractionalMaxPool3d = blackbox_module(nn.FractionalMaxPool3d) AlphaDropout = basic_unit(nn.AlphaDropout)
LPPool1d = blackbox_module(nn.LPPool1d) FeatureAlphaDropout = basic_unit(nn.FeatureAlphaDropout)
LPPool2d = blackbox_module(nn.LPPool2d) ReflectionPad1d = basic_unit(nn.ReflectionPad1d)
LocalResponseNorm = blackbox_module(nn.LocalResponseNorm) ReflectionPad2d = basic_unit(nn.ReflectionPad2d)
BatchNorm1d = blackbox_module(nn.BatchNorm1d) ReplicationPad2d = basic_unit(nn.ReplicationPad2d)
BatchNorm2d = blackbox_module(nn.BatchNorm2d) ReplicationPad1d = basic_unit(nn.ReplicationPad1d)
BatchNorm3d = blackbox_module(nn.BatchNorm3d) ReplicationPad3d = basic_unit(nn.ReplicationPad3d)
InstanceNorm1d = blackbox_module(nn.InstanceNorm1d) CrossMapLRN2d = basic_unit(nn.CrossMapLRN2d)
InstanceNorm2d = blackbox_module(nn.InstanceNorm2d) Embedding = basic_unit(nn.Embedding)
InstanceNorm3d = blackbox_module(nn.InstanceNorm3d) EmbeddingBag = basic_unit(nn.EmbeddingBag)
LayerNorm = blackbox_module(nn.LayerNorm) RNNBase = basic_unit(nn.RNNBase)
GroupNorm = blackbox_module(nn.GroupNorm) RNN = basic_unit(nn.RNN)
SyncBatchNorm = blackbox_module(nn.SyncBatchNorm) LSTM = basic_unit(nn.LSTM)
Dropout = blackbox_module(nn.Dropout) GRU = basic_unit(nn.GRU)
Dropout2d = blackbox_module(nn.Dropout2d) RNNCellBase = basic_unit(nn.RNNCellBase)
Dropout3d = blackbox_module(nn.Dropout3d) RNNCell = basic_unit(nn.RNNCell)
AlphaDropout = blackbox_module(nn.AlphaDropout) LSTMCell = basic_unit(nn.LSTMCell)
FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout) GRUCell = basic_unit(nn.GRUCell)
ReflectionPad1d = blackbox_module(nn.ReflectionPad1d) PixelShuffle = basic_unit(nn.PixelShuffle)
ReflectionPad2d = blackbox_module(nn.ReflectionPad2d) Upsample = basic_unit(nn.Upsample)
ReplicationPad2d = blackbox_module(nn.ReplicationPad2d) UpsamplingNearest2d = basic_unit(nn.UpsamplingNearest2d)
ReplicationPad1d = blackbox_module(nn.ReplicationPad1d) UpsamplingBilinear2d = basic_unit(nn.UpsamplingBilinear2d)
ReplicationPad3d = blackbox_module(nn.ReplicationPad3d) PairwiseDistance = basic_unit(nn.PairwiseDistance)
CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d) AdaptiveMaxPool1d = basic_unit(nn.AdaptiveMaxPool1d)
Embedding = blackbox_module(nn.Embedding) AdaptiveMaxPool2d = basic_unit(nn.AdaptiveMaxPool2d)
EmbeddingBag = blackbox_module(nn.EmbeddingBag) AdaptiveMaxPool3d = basic_unit(nn.AdaptiveMaxPool3d)
RNNBase = blackbox_module(nn.RNNBase) AdaptiveAvgPool1d = basic_unit(nn.AdaptiveAvgPool1d)
RNN = blackbox_module(nn.RNN) AdaptiveAvgPool2d = basic_unit(nn.AdaptiveAvgPool2d)
LSTM = blackbox_module(nn.LSTM) AdaptiveAvgPool3d = basic_unit(nn.AdaptiveAvgPool3d)
GRU = blackbox_module(nn.GRU) TripletMarginLoss = basic_unit(nn.TripletMarginLoss)
RNNCellBase = blackbox_module(nn.RNNCellBase) ZeroPad2d = basic_unit(nn.ZeroPad2d)
RNNCell = blackbox_module(nn.RNNCell) ConstantPad1d = basic_unit(nn.ConstantPad1d)
LSTMCell = blackbox_module(nn.LSTMCell) ConstantPad2d = basic_unit(nn.ConstantPad2d)
GRUCell = blackbox_module(nn.GRUCell) ConstantPad3d = basic_unit(nn.ConstantPad3d)
PixelShuffle = blackbox_module(nn.PixelShuffle) Bilinear = basic_unit(nn.Bilinear)
Upsample = blackbox_module(nn.Upsample) CosineSimilarity = basic_unit(nn.CosineSimilarity)
UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d) Unfold = basic_unit(nn.Unfold)
UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d) Fold = basic_unit(nn.Fold)
PairwiseDistance = blackbox_module(nn.PairwiseDistance) AdaptiveLogSoftmaxWithLoss = basic_unit(nn.AdaptiveLogSoftmaxWithLoss)
AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d) TransformerEncoder = basic_unit(nn.TransformerEncoder)
AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d) TransformerDecoder = basic_unit(nn.TransformerDecoder)
AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d) TransformerEncoderLayer = basic_unit(nn.TransformerEncoderLayer)
AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d) TransformerDecoderLayer = basic_unit(nn.TransformerDecoderLayer)
AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d) Transformer = basic_unit(nn.Transformer)
AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d) Flatten = basic_unit(nn.Flatten)
TripletMarginLoss = blackbox_module(nn.TripletMarginLoss) Hardsigmoid = basic_unit(nn.Hardsigmoid)
ZeroPad2d = blackbox_module(nn.ZeroPad2d)
ConstantPad1d = blackbox_module(nn.ConstantPad1d)
ConstantPad2d = blackbox_module(nn.ConstantPad2d)
ConstantPad3d = blackbox_module(nn.ConstantPad3d)
Bilinear = blackbox_module(nn.Bilinear)
CosineSimilarity = blackbox_module(nn.CosineSimilarity)
Unfold = blackbox_module(nn.Unfold)
Fold = blackbox_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = blackbox_module(nn.TransformerEncoder)
TransformerDecoder = blackbox_module(nn.TransformerDecoder)
TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
Transformer = blackbox_module(nn.Transformer)
Flatten = blackbox_module(nn.Flatten)
Hardsigmoid = blackbox_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'): if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = blackbox_module(nn.Hardswish) Hardswish = basic_unit(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'): if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = blackbox_module(nn.SiLU) SiLU = basic_unit(nn.SiLU)
Unflatten = blackbox_module(nn.Unflatten) Unflatten = basic_unit(nn.Unflatten)
TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss) TripletMarginWithDistanceLoss = basic_unit(nn.TripletMarginWithDistanceLoss)
from .functional import FunctionalTrainer
from .interface import BaseOneShotTrainer from .interface import BaseOneShotTrainer
...@@ -2,11 +2,6 @@ import abc ...@@ -2,11 +2,6 @@ import abc
from typing import Any from typing import Any
class BaseTrainer(abc.ABC):
# Deprecated class
pass
class BaseOneShotTrainer(abc.ABC): class BaseOneShotTrainer(abc.ABC):
""" """
Build many (possibly all) architectures into a full graph, search (with train) and export the best. Build many (possibly all) architectures into a full graph, search (with train) and export the best.
......
from .base import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from .darts import DartsTrainer from .darts import DartsTrainer
from .enas import EnasTrainer from .enas import EnasTrainer
from .proxyless import ProxylessTrainer from .proxyless import ProxylessTrainer
from .random import RandomTrainer, SinglePathTrainer from .random import SinglePathTrainer, RandomTrainer
from .utils import replace_input_choice, replace_layer_choice
import abc
import functools
import inspect
from typing import Any
import json_tricks
from .utils import get_full_class_name, get_module_name, import_
def get_init_parameters_or_fail(obj, silently=False):
    """
    Return the init parameters recorded on ``obj`` by the serialization wrappers.

    Parameters
    ----------
    obj : Any
        Object expected to carry ``_init_parameters`` (set by ``serialize_cls``
        / ``@basic_unit`` wrappers).
    silently : bool
        If true, return ``None`` instead of raising when the attribute is missing.

    Raises
    ------
    ValueError
        If ``obj`` has no ``_init_parameters`` and ``silently`` is false.
    """
    if hasattr(obj, '_init_parameters'):
        return obj._init_parameters
    if silently:
        return None
    # fixed grammar in the user-facing message ("please to decorate" -> "please decorate")
    raise ValueError(f'Object {obj} needs to be serializable but `_init_parameters` is not available. '
                     'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
                     'If it is a customized module, please decorate it with @basic_unit. '
                     'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
                     'try to use serialize or @serialize_cls.')
### This is a patch of json-tricks to make it more useful to us ###
def _serialize_class_instance_encode(obj, primitives=False):
    """
    json-tricks encoder hook: represent an instance as its full class name
    plus its recorded init arguments. Objects without recorded arguments
    are returned unchanged.
    """
    assert not primitives, 'Encoding with primitives is not supported.'
    if hasattr(obj, '__class__'):
        try:  # FIXME: raise error
            arguments = get_init_parameters_or_fail(obj)
        except ValueError:
            return obj
        return {
            '__type__': get_full_class_name(obj.__class__),
            'arguments': arguments
        }
    return obj
def _serialize_class_instance_decode(obj):
    """json-tricks decoder hook: rebuild an instance from a ``{'__type__', 'arguments'}`` dict."""
    is_encoded_instance = isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj
    if not is_encoded_instance:
        return obj
    cls = import_(obj['__type__'])
    return cls(**obj['arguments'])
def _type_encode(obj, primitives=False):
    """json-tricks encoder hook: serialize a class object (not an instance) by its full dotted name."""
    assert not primitives, 'Encoding with primitives is not supported.'
    if not isinstance(obj, type):
        return obj
    return {'__typename__': get_full_class_name(obj, relocate_module=True)}
def _type_decode(obj):
    """json-tricks decoder hook: resolve a ``{'__typename__': ...}`` dict back to the class object."""
    if not isinstance(obj, dict):
        return obj
    if '__typename__' not in obj:
        return obj
    return import_(obj['__typename__'])
# json-tricks entry points with our custom hooks pre-installed.
# The encode hooks turn class instances / type objects into plain dicts;
# the paired decode hooks reverse them on load.
json_loads = functools.partial(json_tricks.loads, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode])
json_dumps = functools.partial(json_tricks.dumps, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode])
json_load = functools.partial(json_tricks.load, extra_obj_pairs_hooks=[_serialize_class_instance_decode, _type_decode])
json_dump = functools.partial(json_tricks.dump, extra_obj_encoders=[_serialize_class_instance_encode, _type_encode])
### End of json-tricks patch ###
class Translatable(abc.ABC):
    """
    Inherit this class and implement ``_translate`` when the inner class needs a different
    parameter from the wrapper class in its init function.
    """

    @abc.abstractmethod
    def _translate(self) -> Any:
        # Return the value that replaces this object when it is forwarded
        # to the wrapped class's __init__.
        pass
def _create_wrapper_cls(cls, store_init_parameters=True):
    """
    Build a subclass of ``cls`` whose ``__init__`` records its call arguments
    in ``self._init_parameters`` before delegating to the original ``__init__``.

    When ``store_init_parameters`` is false, nothing is recorded
    (``_init_parameters`` is ``{}``) and arguments are forwarded untouched —
    see ``transparent_serialize``.
    """
    class wrapper(cls):
        def __init__(self, *args, **kwargs):
            if store_init_parameters:
                # Map positional args onto parameter names; [1:] skips `self`.
                argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
                full_args = {}
                full_args.update(kwargs)

                assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
                for argname, value in zip(argname_list, args):
                    full_args[argname] = value

                # translate parameters
                # NOTE: `full_args` keeps the *pre-translation* values; only the
                # arguments forwarded to the wrapped __init__ are translated.
                args = list(args)
                for i, value in enumerate(args):
                    if isinstance(value, Translatable):
                        args[i] = value._translate()
                for i, value in kwargs.items():
                    if isinstance(value, Translatable):
                        kwargs[i] = value._translate()

                self._init_parameters = full_args
            else:
                self._init_parameters = {}

            super().__init__(*args, **kwargs)

    # Masquerade as the wrapped class so repr / pickling / docs look unchanged.
    wrapper.__module__ = get_module_name(cls)
    wrapper.__name__ = cls.__name__
    wrapper.__qualname__ = cls.__qualname__
    wrapper.__init__.__doc__ = cls.__init__.__doc__

    return wrapper
def serialize_cls(cls):
    """
    To create a serializable class: instances record their init arguments
    in ``_init_parameters`` so they can be dumped and re-created.
    Use it as a decorator.
    """
    return _create_wrapper_cls(cls)
def transparent_serialize(cls):
    """
    Wrap a module without recording its parameters. For internal use only.
    """
    return _create_wrapper_cls(cls, store_init_parameters=False)
def serialize(cls, *args, **kwargs):
    """
    To create a serializable instance inline without decorator. For example,

    .. code-block:: python
        self.op = serialize(MyCustomOp, hidden_units=128)
    """
    return serialize_cls(cls)(*args, **kwargs)
def basic_unit(cls):
    """
    To wrap a module as a basic unit, to stop it from parsing and make it mutate-able.
    Use it as a decorator on an ``nn.Module`` subclass.
    """
    # Imported locally — presumably so this module does not require torch at
    # import time; confirm before hoisting to the top of the file.
    import torch.nn as nn
    assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
    return serialize_cls(cls)
import abc
import functools
import inspect import inspect
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any
from pathlib import Path from pathlib import Path
import json_tricks
def import_(target: str, allow_none: bool = False) -> Any: def import_(target: str, allow_none: bool = False) -> Any:
if target is None: if target is None:
...@@ -23,145 +19,6 @@ def version_larger_equal(a: str, b: str) -> bool: ...@@ -23,145 +19,6 @@ def version_larger_equal(a: str, b: str) -> bool:
return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.'))) return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))
### This is a patch of json-tricks to make it more useful to us ###
def _blackbox_class_instance_encode(obj, primitives=False):
    """
    json-tricks encoder hook: represent an instance carrying
    ``__init_parameters__`` as its full class name plus those arguments.
    Other objects pass through unchanged.
    """
    assert not primitives, 'Encoding with primitives is not supported.'
    if hasattr(obj, '__class__') and hasattr(obj, '__init_parameters__'):
        return {
            '__type__': get_full_class_name(obj.__class__),
            'arguments': obj.__init_parameters__
        }
    return obj
def _blackbox_class_instance_decode(obj):
    """json-tricks decoder hook: rebuild an instance from a ``{'__type__', 'arguments'}`` dict."""
    if isinstance(obj, dict) and '__type__' in obj and 'arguments' in obj:
        return import_(obj['__type__'])(**obj['arguments'])
    return obj
def _type_encode(obj, primitives=False):
    """json-tricks encoder hook: serialize a class object (not an instance) by its full dotted name."""
    assert not primitives, 'Encoding with primitives is not supported.'
    if isinstance(obj, type):
        return {'__typename__': get_full_class_name(obj, relocate_module=True)}
    return obj
def _type_decode(obj):
    """json-tricks decoder hook: resolve a ``{'__typename__': ...}`` dict back to the class object."""
    if isinstance(obj, dict) and '__typename__' in obj:
        return import_(obj['__typename__'])
    return obj
# json-tricks entry points with the blackbox hooks pre-installed; the
# decode hooks reverse the corresponding encode hooks.
json_loads = functools.partial(json_tricks.loads, extra_obj_pairs_hooks=[_blackbox_class_instance_decode, _type_decode])
json_dumps = functools.partial(json_tricks.dumps, extra_obj_encoders=[_blackbox_class_instance_encode, _type_encode])
json_load = functools.partial(json_tricks.load, extra_obj_pairs_hooks=[_blackbox_class_instance_decode, _type_decode])
json_dump = functools.partial(json_tricks.dump, extra_obj_encoders=[_blackbox_class_instance_encode, _type_encode])
### End of json-tricks patch ###
_records = {}
def get_records():
    """Return the module-level record table (instance id -> init arguments)."""
    global _records  # redundant for a read, but harmless
    return _records
def clear_records():
    """Reset the module-level record table to an empty dict."""
    global _records
    _records = {}
def add_record(key, value):
    """
    Record ``value`` under ``key`` in the module-level ``_records`` table
    (callers use ``id(instance)`` as the key). Asserts the key is not
    already present.
    """
    global _records
    if _records is not None:
        assert key not in _records, f'{key} already in _records. Conflict: {_records[key]}'
        _records[key] = value
def del_record(key):
    """Remove ``key`` from the record table if present (no-op otherwise)."""
    global _records
    if _records is not None:
        _records.pop(key, None)
class Translatable(abc.ABC):
    """
    Inherit this class and implement ``_translate`` when the inner class needs a different
    parameter from the wrapper class in its init function.
    """

    @abc.abstractmethod
    def _translate(self) -> Any:
        # Return the value that replaces this object when it is forwarded
        # to the wrapped class's __init__.
        pass
def _blackbox_cls(cls):
    """
    Build a subclass of ``cls`` that records its init arguments in
    ``self.__init_parameters__`` (and in the module-level ``_records`` table,
    keyed by ``id(self)``) before delegating to the original ``__init__``.
    """
    class wrapper(cls):
        def __init__(self, *args, **kwargs):
            # Map positional args onto parameter names; [1:] skips `self`.
            argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
            full_args = {}
            full_args.update(kwargs)

            assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
            for argname, value in zip(argname_list, args):
                full_args[argname] = value

            # translate parameters
            # NOTE: `full_args` keeps the *pre-translation* values; only the
            # arguments forwarded to the wrapped __init__ are translated.
            args = list(args)
            for i, value in enumerate(args):
                if isinstance(value, Translatable):
                    args[i] = value._translate()
            for i, value in kwargs.items():
                if isinstance(value, Translatable):
                    kwargs[i] = value._translate()

            add_record(id(self), full_args)  # for compatibility. Will remove soon.
            self.__init_parameters__ = full_args

            super().__init__(*args, **kwargs)

        def __del__(self):
            # Drop the `_records` entry when the instance is collected.
            del_record(id(self))

    # Masquerade as the wrapped class so repr / pickling / docs look unchanged.
    wrapper.__module__ = _get_module_name(cls)
    wrapper.__name__ = cls.__name__
    wrapper.__qualname__ = cls.__qualname__
    wrapper.__init__.__doc__ = cls.__init__.__doc__

    return wrapper
def blackbox(cls, *args, **kwargs):
    """
    To create a blackbox instance inline without decorator. For example,

    .. code-block:: python
        self.op = blackbox(MyCustomOp, hidden_units=128)
    """
    return _blackbox_cls(cls)(*args, **kwargs)
def blackbox_module(cls):
    """
    Register a module as a blackbox (its init arguments are recorded for
    serialization). Use it as a decorator.
    """
    return _blackbox_cls(cls)
def register_trainer(cls):
    """
    Register a trainer. Use it as a decorator.
    Same wrapping as ``blackbox_module``; kept as a separate name for clarity.
    """
    return _blackbox_cls(cls)
_last_uid = defaultdict(int) _last_uid = defaultdict(int)
...@@ -170,7 +27,7 @@ def uid(namespace: str = 'default') -> int: ...@@ -170,7 +27,7 @@ def uid(namespace: str = 'default') -> int:
return _last_uid[namespace] return _last_uid[namespace]
def _get_module_name(cls): def get_module_name(cls):
module_name = cls.__module__ module_name = cls.__module__
if module_name == '__main__': if module_name == '__main__':
# infer the module name with inspect # infer the module name with inspect
...@@ -180,7 +37,7 @@ def _get_module_name(cls): ...@@ -180,7 +37,7 @@ def _get_module_name(cls):
main_file_path = Path(inspect.getsourcefile(frm[0])) main_file_path = Path(inspect.getsourcefile(frm[0]))
if main_file_path.parents[0] != Path('.'): if main_file_path.parents[0] != Path('.'):
raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, ' raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
f'please launch the experiment under the directory where "{main_file_path.name}" is located.') f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
module_name = main_file_path.stem module_name = main_file_path.stem
break break
...@@ -195,5 +52,5 @@ def _get_module_name(cls): ...@@ -195,5 +52,5 @@ def _get_module_name(cls):
def get_full_class_name(cls, relocate_module=False): def get_full_class_name(cls, relocate_module=False):
module_name = _get_module_name(cls) if relocate_module else cls.__module__ module_name = get_module_name(cls) if relocate_module else cls.__module__
return module_name + '.' + cls.__name__ return module_name + '.' + cls.__name__
...@@ -7,9 +7,9 @@ import torch.nn as torch_nn ...@@ -7,9 +7,9 @@ import torch.nn as torch_nn
import ops import ops
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module from nni.retiarii import basic_unit
@blackbox_module @basic_unit
class AuxiliaryHead(nn.Module): class AuxiliaryHead(nn.Module):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """ """ Auxiliary head in 2/3 place of network to let the gradient flow well """
......
import torch import torch
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module from nni.retiarii import basic_unit
@blackbox_module @basic_unit
class DropPath(nn.Module): class DropPath(nn.Module):
def __init__(self, p=0.): def __init__(self, p=0.):
""" """
...@@ -24,7 +24,7 @@ class DropPath(nn.Module): ...@@ -24,7 +24,7 @@ class DropPath(nn.Module):
return x return x
@blackbox_module @basic_unit
class PoolBN(nn.Module): class PoolBN(nn.Module):
""" """
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`. AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
...@@ -45,7 +45,7 @@ class PoolBN(nn.Module): ...@@ -45,7 +45,7 @@ class PoolBN(nn.Module):
out = self.bn(out) out = self.bn(out)
return out return out
@blackbox_module @basic_unit
class StdConv(nn.Module): class StdConv(nn.Module):
""" """
Standard conv: ReLU - Conv - BN Standard conv: ReLU - Conv - BN
...@@ -61,7 +61,7 @@ class StdConv(nn.Module): ...@@ -61,7 +61,7 @@ class StdConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@blackbox_module @basic_unit
class FacConv(nn.Module): class FacConv(nn.Module):
""" """
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
...@@ -78,7 +78,7 @@ class FacConv(nn.Module): ...@@ -78,7 +78,7 @@ class FacConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@blackbox_module @basic_unit
class DilConv(nn.Module): class DilConv(nn.Module):
""" """
(Dilated) depthwise separable conv. (Dilated) depthwise separable conv.
...@@ -98,7 +98,7 @@ class DilConv(nn.Module): ...@@ -98,7 +98,7 @@ class DilConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@blackbox_module @basic_unit
class SepConv(nn.Module): class SepConv(nn.Module):
""" """
Depthwise separable conv. Depthwise separable conv.
...@@ -114,7 +114,7 @@ class SepConv(nn.Module): ...@@ -114,7 +114,7 @@ class SepConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@blackbox_module @basic_unit
class FactorizedReduce(nn.Module): class FactorizedReduce(nn.Module):
""" """
Reduce feature map size by factorized pointwise (stride=2). Reduce feature map size by factorized pointwise (stride=2).
......
...@@ -4,9 +4,9 @@ import sys ...@@ -4,9 +4,9 @@ import sys
import torch import torch
from pathlib import Path from pathlib import Path
import nni.retiarii.trainer.pytorch.lightning as pl import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.strategy as strategy import nni.retiarii.strategy as strategy
from nni.retiarii import blackbox_module as bm from nni.retiarii import serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
...@@ -27,8 +27,8 @@ if __name__ == '__main__': ...@@ -27,8 +27,8 @@ if __name__ == '__main__':
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]) ])
train_dataset = bm(CIFAR10)(root='data/cifar10', train=True, download=True, transform=train_transform) train_dataset = serialize(CIFAR10, root='data/cifar10', train=True, download=True, transform=train_transform)
test_dataset = bm(CIFAR10)(root='data/cifar10', train=False, download=True, transform=valid_transform) test_dataset = serialize(CIFAR10, root='data/cifar10', train=False, download=True, transform=valid_transform)
trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100), trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100), val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
max_epochs=1, limit_train_batches=0.2) max_epochs=1, limit_train_batches=0.2)
......
...@@ -9,7 +9,7 @@ from torchvision import transforms ...@@ -9,7 +9,7 @@ from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from nni.retiarii.experiment.pytorch import RetiariiExperiment from nni.retiarii.experiment.pytorch import RetiariiExperiment
from nni.retiarii.trainer.pytorch import DartsTrainer from nni.retiarii.oneshot.pytorch import DartsTrainer
from darts_model import CNN from darts_model import CNN
......
from nni.retiarii import blackbox_module from nni.retiarii import basic_unit
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
import warnings import warnings
...@@ -148,7 +148,7 @@ class MNASNet(nn.Module): ...@@ -148,7 +148,7 @@ class MNASNet(nn.Module):
# zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios): # zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios):
for filter_size, exp_ratio, stride in zip(base_filter_sizes, exp_ratios, strides): for filter_size, exp_ratio, stride in zip(base_filter_sizes, exp_ratios, strides):
# TODO: restrict that "choose" can only be used within mutator # TODO: restrict that "choose" can only be used within mutator
ph = nn.Placeholder(label=f'mutable_{count}', related_info={ ph = nn.Placeholder(label=f'mutable_{count}', **{
'kernel_size_options': [1, 3, 5], 'kernel_size_options': [1, 3, 5],
'n_layer_options': [1, 2, 3, 4], 'n_layer_options': [1, 2, 3, 4],
'op_type_options': ['__mutated__.base_mnasnet.RegularConv', 'op_type_options': ['__mutated__.base_mnasnet.RegularConv',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment