Unverified commit ae50ed14, authored by Yuge Zhang, committed by GitHub

Refactor wrap module as "blackbox_module" (#3238)

parent 15da19d3
@@ -2,4 +2,4 @@ from .operation import Operation
 from .graph import *
 from .execution import *
 from .mutator import *
-from .utils import register_module
+from .utils import blackbox, blackbox_module, register_trainer
\ No newline at end of file
@@ -19,10 +19,10 @@ def model_to_pytorch_script(model: Model, placement=None) -> str:
 def _sorted_incoming_edges(node: Node) -> List[Edge]:
     edges = [edge for edge in node.graph.edges if edge.tail is node]
-    _logger.info('sorted_incoming_edges: %s', str(edges))
+    _logger.debug('sorted_incoming_edges: %s', str(edges))
     if not edges:
         return []
-    _logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
+    _logger.debug('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
     if all(edge.tail_slot is None for edge in edges):
         return edges
     if all(isinstance(edge.tail_slot, int) for edge in edges):
......
This diff is collapsed.
@@ -29,6 +29,7 @@ _logger = logging.getLogger(__name__)
 OneShotTrainers = (DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer)

 @dataclass(init=False)
 class RetiariiExeConfig(ConfigBase):
     experiment_name: Optional[str] = None
@@ -125,14 +126,17 @@ class RetiariiExperiment(Experiment):
         except Exception as e:
             _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
             raise e
-        base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
-        assert id(self.trainer) in self.recorded_module_args
-        trainer_config = self.recorded_module_args[id(self.trainer)]
-        base_model.apply_trainer(trainer_config['modulename'], trainer_config['args'])
+        base_model_ir = convert_to_graph(script_module, self.base_model)
+        recorded_module_args = get_records()
+        if id(self.trainer) not in recorded_module_args:
+            raise KeyError('Your trainer is not found in registered classes. You might have forgotten to \
+                register your customized trainer with @register_trainer decorator.')
+        trainer_config = recorded_module_args[id(self.trainer)]
+        base_model_ir.apply_trainer(trainer_config['modulename'], trainer_config['args'])

         # handle inline mutations
-        mutators = self._process_inline_mutation(base_model)
+        mutators = self._process_inline_mutation(base_model_ir)
         if mutators is not None and self.applied_mutators:
             raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \
                 do not use mutators when you use LayerChoice/InputChoice')
@@ -140,7 +144,7 @@ class RetiariiExperiment(Experiment):
             self.applied_mutators = mutators

         _logger.info('Starting strategy...')
-        Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start()
+        Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)).start()
         _logger.info('Strategy started!')

     def start(self, port: int = 8080, debug: bool = False) -> None:
......
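Note: the KeyError above fires when the trainer handed to RetiariiExperiment was never recorded. A minimal sketch of a customized trainer under the new interface follows; the class name and the BaseTrainer import path are illustrative assumptions, not part of this commit.

# Hedged sketch: a customized trainer registered with the new decorator so that
# RetiariiExperiment can look up its recorded constructor arguments by id().
# "MyTrainer" and the BaseTrainer import path are assumptions for illustration.
from nni.retiarii import register_trainer
from nni.retiarii.trainer import BaseTrainer  # assumed import location

@register_trainer
class MyTrainer(BaseTrainer):
    def __init__(self, model, learning_rate=1e-3, max_epochs=1):
        super().__init__()
        # only serializable arguments (int/float/str/dict/list) are recorded;
        # the model object itself is skipped by the 'full' register format
        self.model = model
        self.learning_rate = learning_rate
        self.max_epochs = max_epochs
        # training methods omitted in this sketch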
-import inspect
 import logging
 from typing import Any, List
 import torch
 import torch.nn as nn
-from ...utils import add_record, version_larger_equal
+from ...utils import add_record, blackbox_module, uid, version_larger_equal
 _logger = logging.getLogger(__name__)
@@ -40,16 +39,13 @@ if version_larger_equal(torch.__version__, '1.6.0'):
 if version_larger_equal(torch.__version__, '1.7.0'):
     __all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])
-    #'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
-    #'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
-    #'ChannelShuffle'

 class LayerChoice(nn.Module):
     def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
         super(LayerChoice, self).__init__()
-        self.candidate_ops = op_candidates
-        self.label = key
-        self.key = key  # deprecated, for backward compatibility
+        self.op_candidates = op_candidates
+        self.label = key if key is not None else f'layerchoice_{uid()}'
+        self.key = self.label  # deprecated, for backward compatibility
         for i, module in enumerate(op_candidates):  # deprecated, for backward compatibility
             self.add_module(str(i), module)
         if reduction or return_mask:
@@ -66,8 +62,8 @@ class InputChoice(nn.Module):
         self.n_candidates = n_candidates
         self.n_chosen = n_chosen
         self.reduction = reduction
-        self.label = key
-        self.key = key  # deprecated, for backward compatibility
+        self.label = key if key is not None else f'inputchoice_{uid()}'
+        self.key = self.label  # deprecated, for backward compatibility
         if choose_from or return_mask:
             _logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')
@@ -101,6 +97,7 @@ class Placeholder(nn.Module):

 class ChosenInputs(nn.Module):
     """
     """
     def __init__(self, chosen: List[int], reduction: str):
         super().__init__()
         self.chosen = chosen
@@ -128,9 +125,7 @@ class ChosenInputs(nn.Module):
 # the following are pytorch modules
-class Module(nn.Module):
-    def __init__(self):
-        super(Module, self).__init__()
+Module = nn.Module

 class Sequential(nn.Sequential):
@@ -145,143 +140,116 @@ class ModuleList(nn.ModuleList):
         super(ModuleList, self).__init__(*args)
def wrap_module(original_class): Identity = blackbox_module(nn.Identity)
orig_init = original_class.__init__ Linear = blackbox_module(nn.Linear)
argname_list = list(inspect.signature(original_class).parameters.keys()) Conv1d = blackbox_module(nn.Conv1d)
# Make copy of original __init__, so we can call it without recursion Conv2d = blackbox_module(nn.Conv2d)
Conv3d = blackbox_module(nn.Conv3d)
def __init__(self, *args, **kws): ConvTranspose1d = blackbox_module(nn.ConvTranspose1d)
full_args = {} ConvTranspose2d = blackbox_module(nn.ConvTranspose2d)
full_args.update(kws) ConvTranspose3d = blackbox_module(nn.ConvTranspose3d)
for i, arg in enumerate(args): Threshold = blackbox_module(nn.Threshold)
full_args[argname_list[i]] = arg ReLU = blackbox_module(nn.ReLU)
add_record(id(self), full_args) Hardtanh = blackbox_module(nn.Hardtanh)
ReLU6 = blackbox_module(nn.ReLU6)
orig_init(self, *args, **kws) # Call the original __init__ Sigmoid = blackbox_module(nn.Sigmoid)
Tanh = blackbox_module(nn.Tanh)
original_class.__init__ = __init__ # Set the class' __init__ to the new one Softmax = blackbox_module(nn.Softmax)
return original_class Softmax2d = blackbox_module(nn.Softmax2d)
LogSoftmax = blackbox_module(nn.LogSoftmax)
ELU = blackbox_module(nn.ELU)
Identity = wrap_module(nn.Identity) SELU = blackbox_module(nn.SELU)
Linear = wrap_module(nn.Linear) CELU = blackbox_module(nn.CELU)
Conv1d = wrap_module(nn.Conv1d) GLU = blackbox_module(nn.GLU)
Conv2d = wrap_module(nn.Conv2d) GELU = blackbox_module(nn.GELU)
Conv3d = wrap_module(nn.Conv3d) Hardshrink = blackbox_module(nn.Hardshrink)
ConvTranspose1d = wrap_module(nn.ConvTranspose1d) LeakyReLU = blackbox_module(nn.LeakyReLU)
ConvTranspose2d = wrap_module(nn.ConvTranspose2d) LogSigmoid = blackbox_module(nn.LogSigmoid)
ConvTranspose3d = wrap_module(nn.ConvTranspose3d) Softplus = blackbox_module(nn.Softplus)
Threshold = wrap_module(nn.Threshold) Softshrink = blackbox_module(nn.Softshrink)
ReLU = wrap_module(nn.ReLU) MultiheadAttention = blackbox_module(nn.MultiheadAttention)
Hardtanh = wrap_module(nn.Hardtanh) PReLU = blackbox_module(nn.PReLU)
ReLU6 = wrap_module(nn.ReLU6) Softsign = blackbox_module(nn.Softsign)
Sigmoid = wrap_module(nn.Sigmoid) Softmin = blackbox_module(nn.Softmin)
Tanh = wrap_module(nn.Tanh) Tanhshrink = blackbox_module(nn.Tanhshrink)
Softmax = wrap_module(nn.Softmax) RReLU = blackbox_module(nn.RReLU)
Softmax2d = wrap_module(nn.Softmax2d) AvgPool1d = blackbox_module(nn.AvgPool1d)
LogSoftmax = wrap_module(nn.LogSoftmax) AvgPool2d = blackbox_module(nn.AvgPool2d)
ELU = wrap_module(nn.ELU) AvgPool3d = blackbox_module(nn.AvgPool3d)
SELU = wrap_module(nn.SELU) MaxPool1d = blackbox_module(nn.MaxPool1d)
CELU = wrap_module(nn.CELU) MaxPool2d = blackbox_module(nn.MaxPool2d)
GLU = wrap_module(nn.GLU) MaxPool3d = blackbox_module(nn.MaxPool3d)
GELU = wrap_module(nn.GELU) MaxUnpool1d = blackbox_module(nn.MaxUnpool1d)
Hardshrink = wrap_module(nn.Hardshrink) MaxUnpool2d = blackbox_module(nn.MaxUnpool2d)
LeakyReLU = wrap_module(nn.LeakyReLU) MaxUnpool3d = blackbox_module(nn.MaxUnpool3d)
LogSigmoid = wrap_module(nn.LogSigmoid) FractionalMaxPool2d = blackbox_module(nn.FractionalMaxPool2d)
Softplus = wrap_module(nn.Softplus) FractionalMaxPool3d = blackbox_module(nn.FractionalMaxPool3d)
Softshrink = wrap_module(nn.Softshrink) LPPool1d = blackbox_module(nn.LPPool1d)
MultiheadAttention = wrap_module(nn.MultiheadAttention) LPPool2d = blackbox_module(nn.LPPool2d)
PReLU = wrap_module(nn.PReLU) LocalResponseNorm = blackbox_module(nn.LocalResponseNorm)
Softsign = wrap_module(nn.Softsign) BatchNorm1d = blackbox_module(nn.BatchNorm1d)
Softmin = wrap_module(nn.Softmin) BatchNorm2d = blackbox_module(nn.BatchNorm2d)
Tanhshrink = wrap_module(nn.Tanhshrink) BatchNorm3d = blackbox_module(nn.BatchNorm3d)
RReLU = wrap_module(nn.RReLU) InstanceNorm1d = blackbox_module(nn.InstanceNorm1d)
AvgPool1d = wrap_module(nn.AvgPool1d) InstanceNorm2d = blackbox_module(nn.InstanceNorm2d)
AvgPool2d = wrap_module(nn.AvgPool2d) InstanceNorm3d = blackbox_module(nn.InstanceNorm3d)
AvgPool3d = wrap_module(nn.AvgPool3d) LayerNorm = blackbox_module(nn.LayerNorm)
MaxPool1d = wrap_module(nn.MaxPool1d) GroupNorm = blackbox_module(nn.GroupNorm)
MaxPool2d = wrap_module(nn.MaxPool2d) SyncBatchNorm = blackbox_module(nn.SyncBatchNorm)
MaxPool3d = wrap_module(nn.MaxPool3d) Dropout = blackbox_module(nn.Dropout)
MaxUnpool1d = wrap_module(nn.MaxUnpool1d) Dropout2d = blackbox_module(nn.Dropout2d)
MaxUnpool2d = wrap_module(nn.MaxUnpool2d) Dropout3d = blackbox_module(nn.Dropout3d)
MaxUnpool3d = wrap_module(nn.MaxUnpool3d) AlphaDropout = blackbox_module(nn.AlphaDropout)
FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d) FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout)
FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d) ReflectionPad1d = blackbox_module(nn.ReflectionPad1d)
LPPool1d = wrap_module(nn.LPPool1d) ReflectionPad2d = blackbox_module(nn.ReflectionPad2d)
LPPool2d = wrap_module(nn.LPPool2d) ReplicationPad2d = blackbox_module(nn.ReplicationPad2d)
LocalResponseNorm = wrap_module(nn.LocalResponseNorm) ReplicationPad1d = blackbox_module(nn.ReplicationPad1d)
BatchNorm1d = wrap_module(nn.BatchNorm1d) ReplicationPad3d = blackbox_module(nn.ReplicationPad3d)
BatchNorm2d = wrap_module(nn.BatchNorm2d) CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d)
BatchNorm3d = wrap_module(nn.BatchNorm3d) Embedding = blackbox_module(nn.Embedding)
InstanceNorm1d = wrap_module(nn.InstanceNorm1d) EmbeddingBag = blackbox_module(nn.EmbeddingBag)
InstanceNorm2d = wrap_module(nn.InstanceNorm2d) RNNBase = blackbox_module(nn.RNNBase)
InstanceNorm3d = wrap_module(nn.InstanceNorm3d) RNN = blackbox_module(nn.RNN)
LayerNorm = wrap_module(nn.LayerNorm) LSTM = blackbox_module(nn.LSTM)
GroupNorm = wrap_module(nn.GroupNorm) GRU = blackbox_module(nn.GRU)
SyncBatchNorm = wrap_module(nn.SyncBatchNorm) RNNCellBase = blackbox_module(nn.RNNCellBase)
Dropout = wrap_module(nn.Dropout) RNNCell = blackbox_module(nn.RNNCell)
Dropout2d = wrap_module(nn.Dropout2d) LSTMCell = blackbox_module(nn.LSTMCell)
Dropout3d = wrap_module(nn.Dropout3d) GRUCell = blackbox_module(nn.GRUCell)
AlphaDropout = wrap_module(nn.AlphaDropout) PixelShuffle = blackbox_module(nn.PixelShuffle)
FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout) Upsample = blackbox_module(nn.Upsample)
ReflectionPad1d = wrap_module(nn.ReflectionPad1d) UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d)
ReflectionPad2d = wrap_module(nn.ReflectionPad2d) UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d)
ReplicationPad2d = wrap_module(nn.ReplicationPad2d) PairwiseDistance = blackbox_module(nn.PairwiseDistance)
ReplicationPad1d = wrap_module(nn.ReplicationPad1d) AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d)
ReplicationPad3d = wrap_module(nn.ReplicationPad3d) AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d)
CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d) AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d)
Embedding = wrap_module(nn.Embedding) AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d)
EmbeddingBag = wrap_module(nn.EmbeddingBag) AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d)
RNNBase = wrap_module(nn.RNNBase) AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d)
RNN = wrap_module(nn.RNN) TripletMarginLoss = blackbox_module(nn.TripletMarginLoss)
LSTM = wrap_module(nn.LSTM) ZeroPad2d = blackbox_module(nn.ZeroPad2d)
GRU = wrap_module(nn.GRU) ConstantPad1d = blackbox_module(nn.ConstantPad1d)
RNNCellBase = wrap_module(nn.RNNCellBase) ConstantPad2d = blackbox_module(nn.ConstantPad2d)
RNNCell = wrap_module(nn.RNNCell) ConstantPad3d = blackbox_module(nn.ConstantPad3d)
LSTMCell = wrap_module(nn.LSTMCell) Bilinear = blackbox_module(nn.Bilinear)
GRUCell = wrap_module(nn.GRUCell) CosineSimilarity = blackbox_module(nn.CosineSimilarity)
PixelShuffle = wrap_module(nn.PixelShuffle) Unfold = blackbox_module(nn.Unfold)
Upsample = wrap_module(nn.Upsample) Fold = blackbox_module(nn.Fold)
UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d) AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d) TransformerEncoder = blackbox_module(nn.TransformerEncoder)
PairwiseDistance = wrap_module(nn.PairwiseDistance) TransformerDecoder = blackbox_module(nn.TransformerDecoder)
AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d) TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d) TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d) Transformer = blackbox_module(nn.Transformer)
AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d) Flatten = blackbox_module(nn.Flatten)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d) Hardsigmoid = blackbox_module(nn.Hardsigmoid)
AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
ZeroPad2d = wrap_module(nn.ZeroPad2d)
ConstantPad1d = wrap_module(nn.ConstantPad1d)
ConstantPad2d = wrap_module(nn.ConstantPad2d)
ConstantPad3d = wrap_module(nn.ConstantPad3d)
Bilinear = wrap_module(nn.Bilinear)
CosineSimilarity = wrap_module(nn.CosineSimilarity)
Unfold = wrap_module(nn.Unfold)
Fold = wrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = wrap_module(nn.TransformerEncoder)
TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer)
Flatten = wrap_module(nn.Flatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'): if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = wrap_module(nn.Hardswish) Hardswish = blackbox_module(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'): if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = wrap_module(nn.SiLU) SiLU = blackbox_module(nn.SiLU)
Unflatten = wrap_module(nn.Unflatten) Unflatten = blackbox_module(nn.Unflatten)
TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss) TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss)
#LazyLinear = wrap_module(nn.LazyLinear)
#LazyConv1d = wrap_module(nn.LazyConv1d)
#LazyConv2d = wrap_module(nn.LazyConv2d)
#LazyConv3d = wrap_module(nn.LazyConv3d)
#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
#ChannelShuffle = wrap_module(nn.ChannelShuffle)
\ No newline at end of file
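Note: every layer re-exported above is now a blackbox_module-wrapped version of its torch.nn counterpart, so user models built from nni.retiarii.nn.pytorch automatically get their constructor arguments recorded, and LayerChoice/InputChoice receive auto-generated labels when no key is given. A hedged usage sketch mirroring the MNIST example added later in this commit (the ToyNet class itself is hypothetical):

# Hedged sketch: a base model using the wrapped layers; the LayerChoice label
# defaults to "layerchoice_<uid>" because no key is passed.
import nni.retiarii.nn.pytorch as nn

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.body = nn.LayerChoice([
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, 8, kernel_size=5, padding=2),
        ])

    def forward(self, x):
        return self.body(self.stem(x))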
@@ -43,7 +43,7 @@ def get_default_transform(dataset: str) -> Any:
     return None

-@register_trainer()
+@register_trainer
 class PyTorchImageClassificationTrainer(BaseTrainer):
     """
     Image classification trainer for PyTorch.
@@ -80,7 +80,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
             Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
             only the key ``max_epochs`` is useful.
         """
-        super(PyTorchImageClassificationTrainer, self).__init__()
+        super().__init__()
         self._use_cuda = torch.cuda.is_available()
         self.model = model
         if self._use_cuda:
......
 import inspect
+import warnings
 from collections import defaultdict
 from typing import Any
@@ -10,12 +11,14 @@ def import_(target: str, allow_none: bool = False) -> Any:
     module = __import__(path, globals(), locals(), [identifier])
     return getattr(module, identifier)

 def version_larger_equal(a: str, b: str) -> bool:
     # TODO: refactor later
     a = a.split('+')[0]
     b = b.split('+')[0]
     return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))

 _records = {}
@@ -29,73 +32,87 @@ def add_record(key, value):
     """
     global _records
     if _records is not None:
-        #assert key not in _records, '{} already in _records'.format(key)
+        assert key not in _records, '{} already in _records'.format(key)
         _records[key] = value

-def _register_module(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            full_args[argname_list[i]] = arg
-        add_record(id(self), full_args)
-
-        orig_init(self, *args, **kws)  # Call the original __init__
-
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-def register_module():
-    """
-    Register a module.
-    """
-    # use it as a decorator: @register_module()
-    def _register(cls):
-        m = _register_module(
-            original_class=cls)
-        return m
-    return _register
-
-def _register_trainer(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-
-    full_class_name = original_class.__module__ + '.' + original_class.__name__
-
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            # TODO: support both pytorch and tensorflow
-            from .nn.pytorch import Module
-            if isinstance(args[i], Module):
-                # ignore the base model object
-                continue
-            full_args[argname_list[i]] = arg
-        add_record(id(self), {'modulename': full_class_name, 'args': full_args})
-
-        orig_init(self, *args, **kws)  # Call the original __init__
-
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-def register_trainer():
-    def _register(cls):
-        m = _register_trainer(
-            original_class=cls)
-        return m
-    return _register
+def del_record(key):
+    global _records
+    if _records is not None:
+        _records.pop(key, None)
+
+def _blackbox_cls(cls, module_name, register_format=None):
+    class wrapper(cls):
+        def __init__(self, *args, **kwargs):
+            argname_list = list(inspect.signature(cls).parameters.keys())
+            full_args = {}
+            full_args.update(kwargs)
+            assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
+            for argname, value in zip(argname_list, args):
+                full_args[argname] = value
+
+            # eject un-serializable arguments
+            for k in list(full_args.keys()):
+                # The list is not complete and does not support nested cases.
+                if not isinstance(full_args[k], (int, float, str, dict, list)):
+                    if not (register_format == 'full' and k == 'model'):
+                        # no warning if it is base model in trainer
+                        warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
+                            This is not supported. You can ignore this warning if you are passing the model to trainer.')
+                    full_args.pop(k)
+
+            if register_format == 'args':
+                add_record(id(self), full_args)
+            elif register_format == 'full':
+                full_class_name = cls.__module__ + '.' + cls.__name__
+                add_record(id(self), {'modulename': full_class_name, 'args': full_args})
+
+            super().__init__(*args, **kwargs)
+
+        def __del__(self):
+            del_record(id(self))
+
+    # using module_name instead of cls.__module__ because it's more natural to see where the module gets wrapped
+    # instead of simply putting torch.nn or etc.
+    wrapper.__module__ = module_name
+    wrapper.__name__ = cls.__name__
+    wrapper.__qualname__ = cls.__qualname__
+    wrapper.__init__.__doc__ = cls.__init__.__doc__
+
+    return wrapper
+
+def blackbox(cls, *args, **kwargs):
+    """
+    To create an blackbox instance inline without decorator. For example,
+
+    .. code-block:: python
+
+        self.op = blackbox(MyCustomOp, hidden_units=128)
+    """
+    # get caller module name
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')(*args, **kwargs)
+
+def blackbox_module(cls):
+    """
+    Register a module. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')
+
+def register_trainer(cls):
+    """
+    Register a trainer. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'full')

 _last_uid = defaultdict(int)
......
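Note: the essence of _blackbox_cls above is that the returned subclass records each instance's constructor arguments, keyed by id(self), before delegating to the original __init__. A simplified, self-contained sketch of that idea (not the NNI code itself, no warning/cleanup logic):

# Hedged sketch of the recording wrapper; "sketch_blackbox_cls" is an illustrative name.
import inspect

_records = {}

def sketch_blackbox_cls(cls):
    class wrapper(cls):
        def __init__(self, *args, **kwargs):
            # map positional args onto parameter names, merge keyword args
            argnames = list(inspect.signature(cls).parameters.keys())
            full_args = dict(zip(argnames, args))
            full_args.update(kwargs)
            _records[id(self)] = full_args
            super().__init__(*args, **kwargs)
    wrapper.__name__ = cls.__name__
    return wrapper

class Linear:
    def __init__(self, in_features, out_features):
        self.shape = (in_features, out_features)

Linear = sketch_blackbox_cls(Linear)
layer = Linear(4, out_features=2)
print(_records[id(layer)])  # {'in_features': 4, 'out_features': 2}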
@@ -5,6 +5,7 @@ tuner_result.txt
 assessor_result.txt
 _generated_model.py
+_generated_model_*.py

 data
 generated
@@ -7,9 +7,9 @@ import torch.nn as torch_nn
 import ops
 import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import register_module
+from nni.retiarii import blackbox_module

+@blackbox_module
 class AuxiliaryHead(nn.Module):
     """ Auxiliary head in 2/3 place of network to let the gradient flow well """
@@ -35,7 +35,6 @@ class AuxiliaryHead(nn.Module):
         logits = self.linear(out)
         return logits

-@register_module()
 class Node(nn.Module):
     def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
         super().__init__()
@@ -66,7 +65,6 @@ class Node(nn.Module):
         #out = [self.drop_path(o) if o is not None else None for o in out]
         return self.input_switch(out)

-@register_module()
 class Cell(nn.Module):
     def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
@@ -100,7 +98,6 @@ class Cell(nn.Module):
         output = torch.cat(new_tensors, dim=1)
         return output

-@register_module()
 class CNN(nn.Module):
     def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
......
 import torch
 import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import register_module
+from nni.retiarii import blackbox_module

-@register_module()
+@blackbox_module
 class DropPath(nn.Module):
     def __init__(self, p=0.):
         """
@@ -12,7 +12,7 @@ class DropPath(nn.Module):
         p : float
             Probability of an path to be zeroed.
         """
-        super(DropPath, self).__init__()
+        super().__init__()
         self.p = p

     def forward(self, x):
@@ -24,13 +24,13 @@ class DropPath(nn.Module):
         return x

-@register_module()
+@blackbox_module
 class PoolBN(nn.Module):
     """
     AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
     """
     def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
-        super(PoolBN, self).__init__()
+        super().__init__()
         if pool_type.lower() == 'max':
             self.pool = nn.MaxPool2d(kernel_size, stride, padding)
         elif pool_type.lower() == 'avg':
@@ -45,13 +45,13 @@ class PoolBN(nn.Module):
         out = self.bn(out)
         return out

-@register_module()
+@blackbox_module
 class StdConv(nn.Module):
     """
     Standard conv: ReLU - Conv - BN
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
-        super(StdConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
             nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
@@ -61,13 +61,13 @@ class StdConv(nn.Module):
     def forward(self, x):
         return self.net(x)

-@register_module()
+@blackbox_module
 class FacConv(nn.Module):
     """
     Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
     """
     def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
-        super(FacConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
             nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
@@ -78,7 +78,7 @@ class FacConv(nn.Module):
     def forward(self, x):
         return self.net(x)

-@register_module()
+@blackbox_module
 class DilConv(nn.Module):
     """
     (Dilated) depthwise separable conv.
@@ -86,7 +86,7 @@ class DilConv(nn.Module):
     If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
-        super(DilConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
             nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
@@ -98,14 +98,14 @@ class DilConv(nn.Module):
     def forward(self, x):
         return self.net(x)

-@register_module()
+@blackbox_module
 class SepConv(nn.Module):
     """
     Depthwise separable conv.
     DilConv(dilation=1) * 2.
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
-        super(SepConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
             DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
@@ -114,13 +114,13 @@ class SepConv(nn.Module):
     def forward(self, x):
         return self.net(x)

-@register_module()
+@blackbox_module
 class FactorizedReduce(nn.Module):
     """
     Reduce feature map size by factorized pointwise (stride=2).
     """
     def __init__(self, C_in, C_out, affine=True):
-        super(FactorizedReduce, self).__init__()
+        super().__init__()
         self.relu = nn.ReLU()
         self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
         self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
......
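Note: the same decorator pattern works for any user-defined op, not just the DARTS ops above; the inline form blackbox(SomeOp, ...) can instead wrap a single instantiation of an undecorated class. A hedged sketch with a hypothetical MyConvBN op:

# Hedged sketch: registering a custom op so its constructor arguments are recorded.
# "MyConvBN" is an illustrative name, not part of this commit.
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module

@blackbox_module
class MyConvBN(nn.Module):
    def __init__(self, c_in, c_out, kernel_size=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size, padding=kernel_size // 2, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)

# inline alternative for a class that is not decorated:
# op = blackbox(MyConvBN, 16, 32)   # requires: from nni.retiarii import blackbox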
@@ -13,10 +13,10 @@ from darts_model import CNN
 if __name__ == '__main__':
     base_model = CNN(32, 3, 16, 10, 8)
     trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
                                                 dataset_kwargs={"root": "data/cifar10", "download": True},
                                                 dataloader_kwargs={"batch_size": 32},
                                                 optimizer_kwargs={"lr": 1e-3},
                                                 trainer_kwargs={"max_epochs": 1})
     #simple_startegy = TPEStrategy()
     simple_startegy = RandomStrategy()
@@ -31,4 +31,4 @@ if __name__ == '__main__':
     exp_config.training_service.use_active_gpu = True
     exp_config.training_service.gpu_indices = [1, 2]
-    exp.run(exp_config, 8081, debug=True)
+    exp.run(exp_config, 8081)
@@ -56,8 +56,8 @@ def get_dataset(cls, cutout_length=0):
     valid_transform = transforms.Compose(normalize)
     if cls == "cifar10":
-        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
-        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
+        dataset_train = CIFAR10(root="./data/cifar10", train=True, download=True, transform=train_transform)
+        dataset_valid = CIFAR10(root="./data/cifar10", train=False, download=True, transform=valid_transform)
     else:
         raise NotImplementedError
     return dataset_train, dataset_valid
......
+from nni.retiarii import blackbox_module
+import nni.retiarii.nn.pytorch as nn
 import warnings

 import torch
@@ -8,8 +10,6 @@ import torch.nn.functional as F
 import sys
 from pathlib import Path
 sys.path.append(str(Path(__file__).resolve().parents[2]))
-import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import register_module

 # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
 # 1.0 - tensorflow.
@@ -27,6 +27,7 @@ class _ResidualBlock(nn.Module):
     def forward(self, x):
         return self.net(x) + x

 class _InvertedResidual(nn.Module):
     def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, skip, bn_momentum=0.1):
@@ -110,7 +111,7 @@ def _get_depths(depths, alpha):
     rather than down. """
     return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]

-@register_module()
 class MNASNet(nn.Module):
     """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
     implements the B1 variant of the model.
@@ -127,7 +128,7 @@ class MNASNet(nn.Module):
     def __init__(self, alpha, depths, convops, kernel_sizes, num_layers,
                  skips, num_classes=1000, dropout=0.2):
-        super(MNASNet, self).__init__()
+        super().__init__()
         assert alpha > 0.0
         assert len(depths) == len(convops) == len(kernel_sizes) == len(num_layers) == len(skips) == 7
         self.alpha = alpha
@@ -143,22 +144,22 @@ class MNASNet(nn.Module):
             nn.ReLU(inplace=True),
         ]
         count = 0
-        #for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \
+        # for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \
         #     zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios):
         for filter_size, exp_ratio, stride in zip(base_filter_sizes, exp_ratios, strides):
             # TODO: restrict that "choose" can only be used within mutator
             ph = nn.Placeholder(label=f'mutable_{count}', related_info={
                 'kernel_size_options': [1, 3, 5],
                 'n_layer_options': [1, 2, 3, 4],
                 'op_type_options': ['__mutated__.base_mnasnet.RegularConv',
                                     '__mutated__.base_mnasnet.DepthwiseConv',
                                     '__mutated__.base_mnasnet.MobileConv'],
-                #'se_ratio_options': [0, 0.25],
+                # 'se_ratio_options': [0, 0.25],
                 'skip_options': ['identity', 'no'],
                 'n_filter_options': [int(filter_size*x) for x in [0.75, 1.0, 1.25]],
                 'exp_ratio': exp_ratio,
                 'stride': stride,
                 'in_ch': depths[0] if count == 0 else None
             })
             layers.append(ph)
             '''if conv == "mconv":
@@ -185,7 +186,7 @@ class MNASNet(nn.Module):
         #self.for_test = 10

     def forward(self, x):
-        #if self.for_test == 10:
+        # if self.for_test == 10:
         x = self.layers(x)
         # Equivalent to global avgpool and removing H and W dimensions.
         x = x.mean([2, 3])
@@ -196,7 +197,7 @@ class MNASNet(nn.Module):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 torch_nn.init.kaiming_normal_(m.weight, mode="fan_out",
                                               nonlinearity="relu")
                 if m.bias is not None:
                     torch_nn.init.zeros_(m.bias)
             elif isinstance(m, nn.BatchNorm2d):
@@ -204,16 +205,18 @@ class MNASNet(nn.Module):
                 torch_nn.init.zeros_(m.bias)
             elif isinstance(m, nn.Linear):
                 torch_nn.init.kaiming_uniform_(m.weight, mode="fan_out",
                                                nonlinearity="sigmoid")
                 torch_nn.init.zeros_(m.bias)

 def test_model(model):
     model(torch.randn(2, 3, 224, 224))

-#====================definition of candidate op classes
+# ====================definition of candidate op classes
 BN_MOMENTUM = 1 - 0.9997

 class RegularConv(nn.Module):
     def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
         super().__init__()
@@ -234,6 +237,7 @@ class RegularConv(nn.Module):
             out = out + x
         return out

 class DepthwiseConv(nn.Module):
     def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
         super().__init__()
@@ -257,6 +261,7 @@ class DepthwiseConv(nn.Module):
             out = out + x
         return out

 class MobileConv(nn.Module):
     def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
         super().__init__()
@@ -274,7 +279,7 @@ class MobileConv(nn.Module):
             nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
             nn.ReLU(inplace=True),
             # Depthwise
-            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding= (kernel_size - 1) // 2,
+            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=(kernel_size - 1) // 2,
                       stride=stride, groups=mid_ch, bias=False),
             nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
             nn.ReLU(inplace=True),
@@ -288,5 +293,6 @@ class MobileConv(nn.Module):
             out = out + x
         return out

 # mnasnet0_5
 ir_module = _InvertedResidual(16, 16, 3, 1, 1, True)
\ No newline at end of file
@@ -19,12 +19,12 @@ if __name__ == '__main__':
     _DEFAULT_NUM_LAYERS = [1, 3, 3, 3, 2, 4, 1]
     base_model = MNASNet(0.5, _DEFAULT_DEPTHS, _DEFAULT_CONVOPS, _DEFAULT_KERNEL_SIZES,
                          _DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS)
     trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
                                                 dataset_kwargs={"root": "data/cifar10", "download": True},
                                                 dataloader_kwargs={"batch_size": 32},
                                                 optimizer_kwargs={"lr": 1e-3},
                                                 trainer_kwargs={"max_epochs": 1})

     # new interface
     applied_mutators = []
@@ -41,4 +41,4 @@ if __name__ == '__main__':
     exp_config.max_trial_number = 10
     exp_config.training_service.use_active_gpu = False
-    exp.run(exp_config, 8081, debug=True)
+    exp.run(exp_config, 8081)
import random
import nni.retiarii.nn.pytorch as nn
import torch.nn.functional as F
from nni.retiarii.experiment import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.strategies import RandomStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer
class Net(nn.Module):
def __init__(self, hidden_size):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.LayerChoice([
nn.Linear(4*4*50, hidden_size),
nn.Linear(4*4*50, hidden_size, bias=False)
])
self.fc2 = nn.Linear(hidden_size, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
if __name__ == '__main__':
base_model = Net(128)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="MNIST",
dataset_kwargs={"root": "data/mnist", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
simple_startegy = RandomStrategy()
exp = RetiariiExperiment(base_model, trainer, [], simple_startegy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_search'
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 10
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081 + random.randint(0, 100))