Unverified Commit ae50ed14 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub

Refactor wrap module as "blackbox_module" (#3238)

parent 15da19d3
......@@ -2,4 +2,4 @@ from .operation import Operation
from .graph import *
from .execution import *
from .mutator import *
from .utils import register_module
\ No newline at end of file
from .utils import blackbox, blackbox_module, register_trainer
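The package now exports blackbox, blackbox_module and register_trainer in place of register_module. A minimal sketch of the new decorator-based usage (MyOp is a hypothetical module name):

    import nni.retiarii.nn.pytorch as nn
    from nni.retiarii import blackbox_module

    @blackbox_module  # records MyOp's __init__ arguments so the graph IR can re-instantiate it
    class MyOp(nn.Module):
        def __init__(self, hidden_units=128):
            super().__init__()
            self.fc = nn.Linear(hidden_units, hidden_units)

        def forward(self, x):
            return self.fc(x)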
......@@ -19,10 +19,10 @@ def model_to_pytorch_script(model: Model, placement=None) -> str:
def _sorted_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
_logger.info('sorted_incoming_edges: %s', str(edges))
_logger.debug('sorted_incoming_edges: %s', str(edges))
if not edges:
return []
_logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
_logger.debug('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
if all(edge.tail_slot is None for edge in edges):
return edges
if all(isinstance(edge.tail_slot, int) for edge in edges):
......
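The two messages above were downgraded from info to debug, so they no longer appear with default logging. A quick way to see them again while debugging:

    import logging
    logging.basicConfig(level=logging.DEBUG)  # re-enables the sorted_incoming_edges messages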
......@@ -29,6 +29,7 @@ _logger = logging.getLogger(__name__)
OneShotTrainers = (DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer)
@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
experiment_name: Optional[str] = None
......@@ -125,14 +126,17 @@ class RetiariiExperiment(Experiment):
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
base_model_ir = convert_to_graph(script_module, self.base_model)
assert id(self.trainer) in self.recorded_module_args
trainer_config = self.recorded_module_args[id(self.trainer)]
base_model.apply_trainer(trainer_config['modulename'], trainer_config['args'])
recorded_module_args = get_records()
if id(self.trainer) not in recorded_module_args:
raise KeyError('Your trainer is not found in registered classes. You might have forgotten to \
register your customized trainer with the @register_trainer decorator.')
trainer_config = recorded_module_args[id(self.trainer)]
base_model_ir.apply_trainer(trainer_config['modulename'], trainer_config['args'])
# handle inline mutations
mutators = self._process_inline_mutation(base_model)
mutators = self._process_inline_mutation(base_model_ir)
if mutators is not None and self.applied_mutators:
raise RuntimeError('Mixed usage of LayerChoice/InputChoice and mutators is not yet supported; \
do not use mutators when you use LayerChoice/InputChoice')
......@@ -140,7 +144,7 @@ class RetiariiExperiment(Experiment):
self.applied_mutators = mutators
_logger.info('Starting strategy...')
Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start()
Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)).start()
_logger.info('Strategy started!')
def start(self, port: int = 8080, debug: bool = False) -> None:
......
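Trainer arguments are now looked up in the global records via get_records() instead of being threaded through convert_to_graph, so a customized trainer must be registered or the KeyError above is raised. A hedged sketch (the import path of BaseTrainer is assumed from this commit's layout):

    from nni.retiarii import register_trainer
    from nni.retiarii.trainer import BaseTrainer  # assumed location of BaseTrainer

    @register_trainer  # stores {'modulename': ..., 'args': ...} keyed by id(instance)
    class MyTrainer(BaseTrainer):
        def __init__(self, model, learning_rate=1e-3):
            super().__init__()
            self.model = model
            self.learning_rate = learning_rate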
import inspect
import logging
from typing import Any, List
import torch
import torch.nn as nn
from ...utils import add_record, version_larger_equal
from ...utils import add_record, blackbox_module, uid, version_larger_equal
_logger = logging.getLogger(__name__)
......@@ -40,16 +39,13 @@ if version_larger_equal(torch.__version__, '1.6.0'):
if version_larger_equal(torch.__version__, '1.7.0'):
__all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'ChannelShuffle'
class LayerChoice(nn.Module):
def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
super(LayerChoice, self).__init__()
self.candidate_ops = op_candidates
self.label = key
self.key = key # deprecated, for backward compatibility
self.op_candidates = op_candidates
self.label = key if key is not None else f'layerchoice_{uid()}'
self.key = self.label # deprecated, for backward compatibility
for i, module in enumerate(op_candidates): # deprecated, for backward compatibility
self.add_module(str(i), module)
if reduction or return_mask:
......@@ -66,8 +62,8 @@ class InputChoice(nn.Module):
self.n_candidates = n_candidates
self.n_chosen = n_chosen
self.reduction = reduction
self.label = key
self.key = key # deprecated, for backward compatibility
self.label = key if key is not None else f'inputchoice_{uid()}'
self.key = self.label # deprecated, for backward compatibility
if choose_from or return_mask:
_logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')
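When key is omitted, labels are now generated from a global uid counter instead of defaulting to None, so separate unnamed choices get distinct labels. Roughly (exact counter values depend on instantiation order):

    import nni.retiarii.nn.pytorch as nn

    fc = nn.LayerChoice([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)])
    sk = nn.InputChoice(n_candidates=2)
    print(fc.label, sk.label)  # e.g. layerchoice_1 inputchoice_2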
......@@ -101,6 +97,7 @@ class Placeholder(nn.Module):
class ChosenInputs(nn.Module):
"""
"""
def __init__(self, chosen: List[int], reduction: str):
super().__init__()
self.chosen = chosen
......@@ -128,9 +125,7 @@ class ChosenInputs(nn.Module):
# the following are pytorch modules
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
Module = nn.Module
class Sequential(nn.Sequential):
......@@ -145,143 +140,116 @@ class ModuleList(nn.ModuleList):
super(ModuleList, self).__init__(*args)
def wrap_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = arg
add_record(id(self), full_args)
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
Identity = wrap_module(nn.Identity)
Linear = wrap_module(nn.Linear)
Conv1d = wrap_module(nn.Conv1d)
Conv2d = wrap_module(nn.Conv2d)
Conv3d = wrap_module(nn.Conv3d)
ConvTranspose1d = wrap_module(nn.ConvTranspose1d)
ConvTranspose2d = wrap_module(nn.ConvTranspose2d)
ConvTranspose3d = wrap_module(nn.ConvTranspose3d)
Threshold = wrap_module(nn.Threshold)
ReLU = wrap_module(nn.ReLU)
Hardtanh = wrap_module(nn.Hardtanh)
ReLU6 = wrap_module(nn.ReLU6)
Sigmoid = wrap_module(nn.Sigmoid)
Tanh = wrap_module(nn.Tanh)
Softmax = wrap_module(nn.Softmax)
Softmax2d = wrap_module(nn.Softmax2d)
LogSoftmax = wrap_module(nn.LogSoftmax)
ELU = wrap_module(nn.ELU)
SELU = wrap_module(nn.SELU)
CELU = wrap_module(nn.CELU)
GLU = wrap_module(nn.GLU)
GELU = wrap_module(nn.GELU)
Hardshrink = wrap_module(nn.Hardshrink)
LeakyReLU = wrap_module(nn.LeakyReLU)
LogSigmoid = wrap_module(nn.LogSigmoid)
Softplus = wrap_module(nn.Softplus)
Softshrink = wrap_module(nn.Softshrink)
MultiheadAttention = wrap_module(nn.MultiheadAttention)
PReLU = wrap_module(nn.PReLU)
Softsign = wrap_module(nn.Softsign)
Softmin = wrap_module(nn.Softmin)
Tanhshrink = wrap_module(nn.Tanhshrink)
RReLU = wrap_module(nn.RReLU)
AvgPool1d = wrap_module(nn.AvgPool1d)
AvgPool2d = wrap_module(nn.AvgPool2d)
AvgPool3d = wrap_module(nn.AvgPool3d)
MaxPool1d = wrap_module(nn.MaxPool1d)
MaxPool2d = wrap_module(nn.MaxPool2d)
MaxPool3d = wrap_module(nn.MaxPool3d)
MaxUnpool1d = wrap_module(nn.MaxUnpool1d)
MaxUnpool2d = wrap_module(nn.MaxUnpool2d)
MaxUnpool3d = wrap_module(nn.MaxUnpool3d)
FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d)
LPPool1d = wrap_module(nn.LPPool1d)
LPPool2d = wrap_module(nn.LPPool2d)
LocalResponseNorm = wrap_module(nn.LocalResponseNorm)
BatchNorm1d = wrap_module(nn.BatchNorm1d)
BatchNorm2d = wrap_module(nn.BatchNorm2d)
BatchNorm3d = wrap_module(nn.BatchNorm3d)
InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
LayerNorm = wrap_module(nn.LayerNorm)
GroupNorm = wrap_module(nn.GroupNorm)
SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
Dropout = wrap_module(nn.Dropout)
Dropout2d = wrap_module(nn.Dropout2d)
Dropout3d = wrap_module(nn.Dropout3d)
AlphaDropout = wrap_module(nn.AlphaDropout)
FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
Embedding = wrap_module(nn.Embedding)
EmbeddingBag = wrap_module(nn.EmbeddingBag)
RNNBase = wrap_module(nn.RNNBase)
RNN = wrap_module(nn.RNN)
LSTM = wrap_module(nn.LSTM)
GRU = wrap_module(nn.GRU)
RNNCellBase = wrap_module(nn.RNNCellBase)
RNNCell = wrap_module(nn.RNNCell)
LSTMCell = wrap_module(nn.LSTMCell)
GRUCell = wrap_module(nn.GRUCell)
PixelShuffle = wrap_module(nn.PixelShuffle)
Upsample = wrap_module(nn.Upsample)
UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
PairwiseDistance = wrap_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
ZeroPad2d = wrap_module(nn.ZeroPad2d)
ConstantPad1d = wrap_module(nn.ConstantPad1d)
ConstantPad2d = wrap_module(nn.ConstantPad2d)
ConstantPad3d = wrap_module(nn.ConstantPad3d)
Bilinear = wrap_module(nn.Bilinear)
CosineSimilarity = wrap_module(nn.CosineSimilarity)
Unfold = wrap_module(nn.Unfold)
Fold = wrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = wrap_module(nn.TransformerEncoder)
TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer)
Flatten = wrap_module(nn.Flatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
Identity = blackbox_module(nn.Identity)
Linear = blackbox_module(nn.Linear)
Conv1d = blackbox_module(nn.Conv1d)
Conv2d = blackbox_module(nn.Conv2d)
Conv3d = blackbox_module(nn.Conv3d)
ConvTranspose1d = blackbox_module(nn.ConvTranspose1d)
ConvTranspose2d = blackbox_module(nn.ConvTranspose2d)
ConvTranspose3d = blackbox_module(nn.ConvTranspose3d)
Threshold = blackbox_module(nn.Threshold)
ReLU = blackbox_module(nn.ReLU)
Hardtanh = blackbox_module(nn.Hardtanh)
ReLU6 = blackbox_module(nn.ReLU6)
Sigmoid = blackbox_module(nn.Sigmoid)
Tanh = blackbox_module(nn.Tanh)
Softmax = blackbox_module(nn.Softmax)
Softmax2d = blackbox_module(nn.Softmax2d)
LogSoftmax = blackbox_module(nn.LogSoftmax)
ELU = blackbox_module(nn.ELU)
SELU = blackbox_module(nn.SELU)
CELU = blackbox_module(nn.CELU)
GLU = blackbox_module(nn.GLU)
GELU = blackbox_module(nn.GELU)
Hardshrink = blackbox_module(nn.Hardshrink)
LeakyReLU = blackbox_module(nn.LeakyReLU)
LogSigmoid = blackbox_module(nn.LogSigmoid)
Softplus = blackbox_module(nn.Softplus)
Softshrink = blackbox_module(nn.Softshrink)
MultiheadAttention = blackbox_module(nn.MultiheadAttention)
PReLU = blackbox_module(nn.PReLU)
Softsign = blackbox_module(nn.Softsign)
Softmin = blackbox_module(nn.Softmin)
Tanhshrink = blackbox_module(nn.Tanhshrink)
RReLU = blackbox_module(nn.RReLU)
AvgPool1d = blackbox_module(nn.AvgPool1d)
AvgPool2d = blackbox_module(nn.AvgPool2d)
AvgPool3d = blackbox_module(nn.AvgPool3d)
MaxPool1d = blackbox_module(nn.MaxPool1d)
MaxPool2d = blackbox_module(nn.MaxPool2d)
MaxPool3d = blackbox_module(nn.MaxPool3d)
MaxUnpool1d = blackbox_module(nn.MaxUnpool1d)
MaxUnpool2d = blackbox_module(nn.MaxUnpool2d)
MaxUnpool3d = blackbox_module(nn.MaxUnpool3d)
FractionalMaxPool2d = blackbox_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = blackbox_module(nn.FractionalMaxPool3d)
LPPool1d = blackbox_module(nn.LPPool1d)
LPPool2d = blackbox_module(nn.LPPool2d)
LocalResponseNorm = blackbox_module(nn.LocalResponseNorm)
BatchNorm1d = blackbox_module(nn.BatchNorm1d)
BatchNorm2d = blackbox_module(nn.BatchNorm2d)
BatchNorm3d = blackbox_module(nn.BatchNorm3d)
InstanceNorm1d = blackbox_module(nn.InstanceNorm1d)
InstanceNorm2d = blackbox_module(nn.InstanceNorm2d)
InstanceNorm3d = blackbox_module(nn.InstanceNorm3d)
LayerNorm = blackbox_module(nn.LayerNorm)
GroupNorm = blackbox_module(nn.GroupNorm)
SyncBatchNorm = blackbox_module(nn.SyncBatchNorm)
Dropout = blackbox_module(nn.Dropout)
Dropout2d = blackbox_module(nn.Dropout2d)
Dropout3d = blackbox_module(nn.Dropout3d)
AlphaDropout = blackbox_module(nn.AlphaDropout)
FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout)
ReflectionPad1d = blackbox_module(nn.ReflectionPad1d)
ReflectionPad2d = blackbox_module(nn.ReflectionPad2d)
ReplicationPad2d = blackbox_module(nn.ReplicationPad2d)
ReplicationPad1d = blackbox_module(nn.ReplicationPad1d)
ReplicationPad3d = blackbox_module(nn.ReplicationPad3d)
CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d)
Embedding = blackbox_module(nn.Embedding)
EmbeddingBag = blackbox_module(nn.EmbeddingBag)
RNNBase = blackbox_module(nn.RNNBase)
RNN = blackbox_module(nn.RNN)
LSTM = blackbox_module(nn.LSTM)
GRU = blackbox_module(nn.GRU)
RNNCellBase = blackbox_module(nn.RNNCellBase)
RNNCell = blackbox_module(nn.RNNCell)
LSTMCell = blackbox_module(nn.LSTMCell)
GRUCell = blackbox_module(nn.GRUCell)
PixelShuffle = blackbox_module(nn.PixelShuffle)
Upsample = blackbox_module(nn.Upsample)
UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d)
PairwiseDistance = blackbox_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = blackbox_module(nn.TripletMarginLoss)
ZeroPad2d = blackbox_module(nn.ZeroPad2d)
ConstantPad1d = blackbox_module(nn.ConstantPad1d)
ConstantPad2d = blackbox_module(nn.ConstantPad2d)
ConstantPad3d = blackbox_module(nn.ConstantPad3d)
Bilinear = blackbox_module(nn.Bilinear)
CosineSimilarity = blackbox_module(nn.CosineSimilarity)
Unfold = blackbox_module(nn.Unfold)
Fold = blackbox_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = blackbox_module(nn.TransformerEncoder)
TransformerDecoder = blackbox_module(nn.TransformerDecoder)
TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
Transformer = blackbox_module(nn.Transformer)
Flatten = blackbox_module(nn.Flatten)
Hardsigmoid = blackbox_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = wrap_module(nn.Hardswish)
Hardswish = blackbox_module(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = wrap_module(nn.SiLU)
Unflatten = wrap_module(nn.Unflatten)
TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
#LazyLinear = wrap_module(nn.LazyLinear)
#LazyConv1d = wrap_module(nn.LazyConv1d)
#LazyConv2d = wrap_module(nn.LazyConv2d)
#LazyConv3d = wrap_module(nn.LazyConv3d)
#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
#ChannelShuffle = wrap_module(nn.ChannelShuffle)
\ No newline at end of file
SiLU = blackbox_module(nn.SiLU)
Unflatten = blackbox_module(nn.Unflatten)
TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss)
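All of these re-exports are now wrapped by blackbox_module, so constructing any of them records its init arguments for graph conversion. For example (the import path of get_records is assumed from this commit):

    import nni.retiarii.nn.pytorch as nn
    from nni.retiarii.utils import get_records  # assumed location of get_records

    conv = nn.Conv2d(3, 16, kernel_size=3)
    args = get_records()[id(conv)]  # records roughly {'in_channels': 3, 'out_channels': 16, 'kernel_size': 3}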
......@@ -43,7 +43,7 @@ def get_default_transform(dataset: str) -> Any:
return None
@register_trainer()
@register_trainer
class PyTorchImageClassificationTrainer(BaseTrainer):
"""
Image classification trainer for PyTorch.
......@@ -80,7 +80,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
Keyword arguments passed to the trainer. These will be passed to a Trainer class in the future. Currently,
only the key ``max_epochs`` is useful.
"""
super(PyTorchImageClassificationTrainer, self).__init__()
super().__init__()
self._use_cuda = torch.cuda.is_available()
self.model = model
if self._use_cuda:
......
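Note the decorator is now applied without parentheses: register_trainer takes the class directly instead of returning an inner decorator, i.e.:

    @register_trainer          # new style
    class PyTorchImageClassificationTrainer(BaseTrainer):
        ...

    # old style, no longer valid:
    # @register_trainer()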
import inspect
import warnings
from collections import defaultdict
from typing import Any
......@@ -10,12 +11,14 @@ def import_(target: str, allow_none: bool = False) -> Any:
module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier)
def version_larger_equal(a: str, b: str) -> bool:
# TODO: refactor later
a = a.split('+')[0]
b = b.split('+')[0]
return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))
_records = {}
......@@ -29,73 +32,87 @@ def add_record(key, value):
"""
global _records
if _records is not None:
#assert key not in _records, '{} already in _records'.format(key)
assert key not in _records, '{} already in _records'.format(key)
_records[key] = value
def _register_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
def del_record(key):
global _records
if _records is not None:
_records.pop(key, None)
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = arg
add_record(id(self), full_args)
orig_init(self, *args, **kws) # Call the original __init__
def _blackbox_cls(cls, module_name, register_format=None):
class wrapper(cls):
def __init__(self, *args, **kwargs):
argname_list = list(inspect.signature(cls).parameters.keys())
full_args = {}
full_args.update(kwargs)
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
assert len(args) <= len(argname_list), f'Got {len(args)} positional arguments, but {cls} takes at most {len(argname_list)}.'
for argname, value in zip(argname_list, args):
full_args[argname] = value
# drop un-serializable arguments
for k in list(full_args.keys()):
# the type list here is not exhaustive and does not support nested cases
if not isinstance(full_args[k], (int, float, str, dict, list)):
if not (register_format == 'full' and k == 'model'):
# no warning if it is the base model passed to a trainer
warnings.warn(f'{cls} has an un-serializable argument {k} whose value is {full_args[k]}. \
This is not supported. You can ignore this warning if you are passing the model to a trainer.')
full_args.pop(k)
def register_module():
"""
Register a module.
"""
# use it as a decorator: @register_module()
def _register(cls):
m = _register_module(
original_class=cls)
return m
if register_format == 'args':
add_record(id(self), full_args)
elif register_format == 'full':
full_class_name = cls.__module__ + '.' + cls.__name__
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
return _register
super().__init__(*args, **kwargs)
def __del__(self):
del_record(id(self))
def _register_trainer(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
# use module_name instead of cls.__module__ because it is more natural to report where the module
# gets wrapped, rather than just torch.nn or similar
wrapper.__module__ = module_name
wrapper.__name__ = cls.__name__
wrapper.__qualname__ = cls.__qualname__
wrapper.__init__.__doc__ = cls.__init__.__doc__
full_class_name = original_class.__module__ + '.' + original_class.__name__
return wrapper
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
# TODO: support both pytorch and tensorflow
from .nn.pytorch import Module
if isinstance(args[i], Module):
# ignore the base model object
continue
full_args[argname_list[i]] = arg
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
orig_init(self, *args, **kws) # Call the original __init__
def blackbox(cls, *args, **kwargs):
"""
Create a blackbox instance inline, without using the decorator. For example,
.. code-block:: python
self.op = blackbox(MyCustomOp, hidden_units=128)
"""
# get caller module name
frm = inspect.stack()[1]
module_name = inspect.getmodule(frm[0]).__name__
return _blackbox_cls(cls, module_name, 'args')(*args, **kwargs)
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
def blackbox_module(cls):
"""
Register a module. Use it as a decorator.
"""
frm = inspect.stack()[1]
module_name = inspect.getmodule(frm[0]).__name__
return _blackbox_cls(cls, module_name, 'args')
def register_trainer():
def _register(cls):
m = _register_trainer(
original_class=cls)
return m
return _register
def register_trainer(cls):
"""
Register a trainer. Use it as a decorator.
"""
frm = inspect.stack()[1]
module_name = inspect.getmodule(frm[0]).__name__
return _blackbox_cls(cls, module_name, 'full')
_last_uid = defaultdict(int)
......
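_blackbox_cls is the single mechanism behind blackbox, blackbox_module and register_trainer: it subclasses the target class, records the bound __init__ arguments under id(self), and drops the record in __del__. A self-contained sketch of the same idea, independent of NNI:

    import inspect

    _records = {}

    def record_init_args(cls):
        # simplified stand-in for _blackbox_cls with register_format='args'
        class wrapper(cls):
            def __init__(self, *args, **kwargs):
                params = list(inspect.signature(cls).parameters.keys())
                full_args = dict(kwargs)
                for name, value in zip(params, args):
                    full_args[name] = value
                _records[id(self)] = full_args  # remember how this instance was built
                super().__init__(*args, **kwargs)

            def __del__(self):
                _records.pop(id(self), None)

        wrapper.__name__ = cls.__name__
        return wrapper

    @record_init_args
    class Linear:
        def __init__(self, in_features, out_features, bias=True):
            self.shape = (in_features, out_features)

    layer = Linear(4, 8)
    assert _records[id(layer)] == {'in_features': 4, 'out_features': 8}

The _last_uid counter above backs the uid() helper used for auto-generated labels; its body is not shown in this diff, but a plausible sketch is:

    from collections import defaultdict
    _last_uid = defaultdict(int)

    def uid(namespace='default'):  # plausible sketch; actual body not shown in the diff
        _last_uid[namespace] += 1
        return _last_uid[namespace]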
......@@ -5,6 +5,7 @@ tuner_result.txt
assessor_result.txt
_generated_model.py
_generated_model_*.py
data
generated
......@@ -7,9 +7,9 @@ import torch.nn as torch_nn
import ops
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module
from nni.retiarii import blackbox_module
@blackbox_module
class AuxiliaryHead(nn.Module):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """
......@@ -35,7 +35,6 @@ class AuxiliaryHead(nn.Module):
logits = self.linear(out)
return logits
@register_module()
class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__()
......@@ -66,7 +65,6 @@ class Node(nn.Module):
#out = [self.drop_path(o) if o is not None else None for o in out]
return self.input_switch(out)
@register_module()
class Cell(nn.Module):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
......@@ -100,7 +98,6 @@ class Cell(nn.Module):
output = torch.cat(new_tensors, dim=1)
return output
@register_module()
class CNN(nn.Module):
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
......
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module
from nni.retiarii import blackbox_module
@register_module()
@blackbox_module
class DropPath(nn.Module):
def __init__(self, p=0.):
"""
......@@ -12,7 +12,7 @@ class DropPath(nn.Module):
p : float
Probability of a path being zeroed.
"""
super(DropPath, self).__init__()
super().__init__()
self.p = p
def forward(self, x):
......@@ -24,13 +24,13 @@ class DropPath(nn.Module):
return x
@register_module()
@blackbox_module
class PoolBN(nn.Module):
"""
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
super(PoolBN, self).__init__()
super().__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
......@@ -45,13 +45,13 @@ class PoolBN(nn.Module):
out = self.bn(out)
return out
@register_module()
@blackbox_module
class StdConv(nn.Module):
"""
Standard conv: ReLU - Conv - BN
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(StdConv, self).__init__()
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
......@@ -61,13 +61,13 @@ class StdConv(nn.Module):
def forward(self, x):
return self.net(x)
@register_module()
@blackbox_module
class FacConv(nn.Module):
"""
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super(FacConv, self).__init__()
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
......@@ -78,7 +78,7 @@ class FacConv(nn.Module):
def forward(self, x):
return self.net(x)
@register_module()
@blackbox_module
class DilConv(nn.Module):
"""
(Dilated) depthwise separable conv.
......@@ -86,7 +86,7 @@ class DilConv(nn.Module):
If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super(DilConv, self).__init__()
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
......@@ -98,14 +98,14 @@ class DilConv(nn.Module):
def forward(self, x):
return self.net(x)
@register_module()
@blackbox_module
class SepConv(nn.Module):
"""
Depthwise separable conv.
DilConv(dilation=1) * 2.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__()
super().__init__()
self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
......@@ -114,13 +114,13 @@ class SepConv(nn.Module):
def forward(self, x):
return self.net(x)
@register_module()
@blackbox_module
class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise (stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super(FactorizedReduce, self).__init__()
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
......
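Each op above is decorated with @blackbox_module, so graph conversion treats it as an opaque node and records its constructor arguments instead of tracing into it. The decorated ops can then serve directly as LayerChoice candidates, e.g. (channel numbers are illustrative):

    choice = nn.LayerChoice([
        PoolBN('max', 16, 3, 1, 1),
        SepConv(16, 16, 3, 1, 1),
    ])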
......@@ -13,10 +13,10 @@ from darts_model import CNN
if __name__ == '__main__':
base_model = CNN(32, 3, 16, 10, 8)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
dataset_kwargs={"root": "data/cifar10", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
dataset_kwargs={"root": "data/cifar10", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
#simple_startegy = TPEStrategy()
simple_startegy = RandomStrategy()
......@@ -31,4 +31,4 @@ if __name__ == '__main__':
exp_config.training_service.use_active_gpu = True
exp_config.training_service.gpu_indices = [1, 2]
exp.run(exp_config, 8081, debug=True)
exp.run(exp_config, 8081)
......@@ -56,8 +56,8 @@ def get_dataset(cls, cutout_length=0):
valid_transform = transforms.Compose(normalize)
if cls == "cifar10":
dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
dataset_train = CIFAR10(root="./data/cifar10", train=True, download=True, transform=train_transform)
dataset_valid = CIFAR10(root="./data/cifar10", train=False, download=True, transform=valid_transform)
else:
raise NotImplementedError
return dataset_train, dataset_valid
......
from nni.retiarii import blackbox_module
import nni.retiarii.nn.pytorch as nn
import warnings
import torch
......@@ -8,8 +10,6 @@ import torch.nn.functional as F
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[2]))
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module
# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow.
......@@ -27,6 +27,7 @@ class _ResidualBlock(nn.Module):
def forward(self, x):
return self.net(x) + x
class _InvertedResidual(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, skip, bn_momentum=0.1):
......@@ -110,7 +111,7 @@ def _get_depths(depths, alpha):
rather than down. """
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
@register_module()
class MNASNet(nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model.
......@@ -127,7 +128,7 @@ class MNASNet(nn.Module):
def __init__(self, alpha, depths, convops, kernel_sizes, num_layers,
skips, num_classes=1000, dropout=0.2):
super(MNASNet, self).__init__()
super().__init__()
assert alpha > 0.0
assert len(depths) == len(convops) == len(kernel_sizes) == len(num_layers) == len(skips) == 7
self.alpha = alpha
......@@ -143,22 +144,22 @@ class MNASNet(nn.Module):
nn.ReLU(inplace=True),
]
count = 0
#for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \
# for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \
# zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios):
for filter_size, exp_ratio, stride in zip(base_filter_sizes, exp_ratios, strides):
# TODO: restrict that "choose" can only be used within mutator
ph = nn.Placeholder(label=f'mutable_{count}', related_info={
'kernel_size_options': [1, 3, 5],
'n_layer_options': [1, 2, 3, 4],
'op_type_options': ['__mutated__.base_mnasnet.RegularConv',
'__mutated__.base_mnasnet.DepthwiseConv',
'__mutated__.base_mnasnet.MobileConv'],
#'se_ratio_options': [0, 0.25],
'skip_options': ['identity', 'no'],
'n_filter_options': [int(filter_size*x) for x in [0.75, 1.0, 1.25]],
'exp_ratio': exp_ratio,
'stride': stride,
'in_ch': depths[0] if count == 0 else None
'kernel_size_options': [1, 3, 5],
'n_layer_options': [1, 2, 3, 4],
'op_type_options': ['__mutated__.base_mnasnet.RegularConv',
'__mutated__.base_mnasnet.DepthwiseConv',
'__mutated__.base_mnasnet.MobileConv'],
# 'se_ratio_options': [0, 0.25],
'skip_options': ['identity', 'no'],
'n_filter_options': [int(filter_size*x) for x in [0.75, 1.0, 1.25]],
'exp_ratio': exp_ratio,
'stride': stride,
'in_ch': depths[0] if count == 0 else None
})
layers.append(ph)
'''if conv == "mconv":
......@@ -185,7 +186,7 @@ class MNASNet(nn.Module):
#self.for_test = 10
def forward(self, x):
#if self.for_test == 10:
# if self.for_test == 10:
x = self.layers(x)
# Equivalent to global avgpool and removing H and W dimensions.
x = x.mean([2, 3])
......@@ -196,7 +197,7 @@ class MNASNet(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch_nn.init.kaiming_normal_(m.weight, mode="fan_out",
nonlinearity="relu")
nonlinearity="relu")
if m.bias is not None:
torch_nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
......@@ -204,16 +205,18 @@ class MNASNet(nn.Module):
torch_nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
torch_nn.init.kaiming_uniform_(m.weight, mode="fan_out",
nonlinearity="sigmoid")
nonlinearity="sigmoid")
torch_nn.init.zeros_(m.bias)
def test_model(model):
model(torch.randn(2, 3, 224, 224))
#====================definition of candidate op classes
# ====================definition of candidate op classes
BN_MOMENTUM = 1 - 0.9997
class RegularConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__()
......@@ -234,6 +237,7 @@ class RegularConv(nn.Module):
out = out + x
return out
class DepthwiseConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__()
......@@ -257,6 +261,7 @@ class DepthwiseConv(nn.Module):
out = out + x
return out
class MobileConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__()
......@@ -274,7 +279,7 @@ class MobileConv(nn.Module):
nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True),
# Depthwise
nn.Conv2d(mid_ch, mid_ch, kernel_size, padding= (kernel_size - 1) // 2,
nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=(kernel_size - 1) // 2,
stride=stride, groups=mid_ch, bias=False),
nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True),
......@@ -288,5 +293,6 @@ class MobileConv(nn.Module):
out = out + x
return out
# mnasnet0_5
ir_module = _InvertedResidual(16, 16, 3, 1, 1, True)
\ No newline at end of file
ir_module = _InvertedResidual(16, 16, 3, 1, 1, True)
......@@ -19,12 +19,12 @@ if __name__ == '__main__':
_DEFAULT_NUM_LAYERS = [1, 3, 3, 3, 2, 4, 1]
base_model = MNASNet(0.5, _DEFAULT_DEPTHS, _DEFAULT_CONVOPS, _DEFAULT_KERNEL_SIZES,
_DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS)
_DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
dataset_kwargs={"root": "data/cifar10", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
dataset_kwargs={"root": "data/cifar10", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
# new interface
applied_mutators = []
......@@ -41,4 +41,4 @@ if __name__ == '__main__':
exp_config.max_trial_number = 10
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081, debug=True)
exp.run(exp_config, 8081)
import random
import nni.retiarii.nn.pytorch as nn
import torch.nn.functional as F
from nni.retiarii.experiment import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.strategies import RandomStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer
class Net(nn.Module):
def __init__(self, hidden_size):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.LayerChoice([
nn.Linear(4*4*50, hidden_size),
nn.Linear(4*4*50, hidden_size, bias=False)
])
self.fc2 = nn.Linear(hidden_size, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
if __name__ == '__main__':
base_model = Net(128)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="MNIST",
dataset_kwargs={"root": "data/mnist", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
simple_strategy = RandomStrategy()
exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_search'
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 10
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081 + random.randint(0, 100))