Unverified Commit 192a807b authored by QuanluZhang's avatar QuanluZhang Committed by GitHub

[Retiarii] refactor based on the new launch approach (#3185)

parent 80394047
......@@ -52,7 +52,7 @@ else:
def connect(self) -> BufferedIOBase:
conn, _ = self._socket.accept()
self.file = conn.makefile('w+b')
self.file = conn.makefile('rwb')
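# note: socket.makefile() only accepts the mode flags 'r', 'w' and 'b', so 'rwb'
# (read-write binary) replaces 'w+b', which open()-style modes allow but makefile() rejects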
return self.file
def close(self) -> None:
......
......@@ -2,3 +2,4 @@ from .operation import Operation
from .graph import *
from .execution import *
from .mutator import *
from .utils import register_module
\ No newline at end of file
......@@ -15,7 +15,6 @@ def model_to_pytorch_script(model: Model, placement = None) -> str:
import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement = placement)
graphs.append(graph_code)
total_pkgs.update(import_pkgs)
# FIXME: set correct PATH for the packages (after launch refactor)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
......@@ -71,7 +70,7 @@ def _remove_prefix(names, graph_name):
return names[len(graph_name):] if names.startswith(graph_name) else names
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> str:
nodes = graph.topo_sort() # FIXME: topological sort is needed here
nodes = graph.topo_sort()
# handle module node and function node differently
# only need to generate code for module here
......@@ -130,10 +129,6 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# FIXME: remove these two lines
import sys
sys.path.append("test/convert_test")
{}
{}
......
import json_tricks
import logging
import re
import torch
......@@ -6,9 +7,10 @@ from ..graph import Graph, Node, Edge, Model
from ..operation import Cell, Operation
from ..nn.pytorch import Placeholder, LayerChoice, InputChoice
from .op_types import MODULE_EXCEPT_LIST, Type
from .op_types import MODULE_EXCEPT_LIST, OpTypeName, BasicOpsPT
from .utils import build_full_name, _convert_name
_logger = logging.getLogger(__name__)
global_seq = 0
global_graph_id = 0
......@@ -80,7 +82,7 @@ def create_prim_constant_node(ir_graph, node, module_name):
if node.outputsAt(0).toIValue() is not None:
attrs = {'value': node.outputsAt(0).toIValue()}
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, Type.Constant, global_seq),
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, global_seq),
node.kind(), attrs)
return new_node
......@@ -163,6 +165,33 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
# key: tensor (%out.1), value: node (this node)
output_remap = {}
def handle_if_condition(cond_tensor):
"""
to calculate the condition, we only deal with the following op types by tracing back
`prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`
generate the expression using recursive calls
NOTE: do not support dynamic graph
"""
def _generate_expr(tensor):
if tensor.node().kind() == 'prim::GetAttr':
return f'({getattr(module, tensor.node().s("name"))})'
elif tensor.node().kind() == 'aten::__getitem__':
t = _generate_expr(tensor.node().inputsAt(0))
idx = _generate_expr(tensor.node().inputsAt(1))
return f'({t}[{idx}])'
elif tensor.node().kind() == 'prim::Constant':
return f'{tensor.toIValue()}'
elif tensor.node().kind() == 'aten::eq':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})'
else:
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition')
expr = _generate_expr(cond_tensor)
return eval(expr)
def handle_if_node(node):
"""
Parameters
......@@ -179,19 +208,13 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
# constant expressions will be supported in the future
inputs = [i for i in node.inputs()]
assert len(inputs) == 1
if not inputs[0].node().kind() in ['prim::Constant', 'prim::GetAttr']:
raise RuntimeError('"if" whose condition is not constant or attribute has not been supported yet!')
chosen_block = None
if inputs[0].node().kind() == 'prim::Constant':
chosen_block = 0 if inputs[0].toIValue() else 1
if inputs[0].node().kind() == 'prim::GetAttr':
chosen_block = 0 if getattr(module, inputs[0].node().s('name')) else 1
cond = handle_if_condition(inputs[0])
chosen_block = 0 if cond else 1
blocks = [block for block in node.blocks()]
assert len(blocks) == 2
last_block_node = None
for node in blocks[chosen_block].nodes():
last_block_node = handle_single_node(node)
assert last_block_node is not None
return last_block_node
def handle_single_node(node):
......@@ -287,29 +310,33 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct':
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, Type.ListConstruct, global_seq), node.kind())
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, global_seq), node.kind())
node_index[node] = new_node
_add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append':
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, Type.BasicOpsPT[node.kind()], global_seq), node.kind())
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, Type.BasicOpsPT[node.kind()], global_seq), node.kind())
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = handle_prim_attr_node(node)
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, Type.Attr, global_seq),
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::min':
raise RuntimeError(f'Unsupported op type prim::min, graph: {sm_graph}')
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# last_block_node being None means the chosen branch block contains no nodes
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
raise RuntimeError('Loop is not supported yet!')
......@@ -343,7 +370,10 @@ def merge_aten_slices(ir_graph):
for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), Type.MergedSlice)
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
if len(head_node.incoming_edges) == 4:
# when the slice is over a one-dimensional list, there are only 4 inputs, so no merge is needed
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
......@@ -383,10 +413,13 @@ def _handle_layerchoice(module):
m_attrs = {}
candidates = module.candidate_ops
choices = []
for i, cand in enumerate(candidates):
assert id(cand) in modules_arg, 'id does not exist: {}'.format(id(cand))
assert isinstance(modules_arg[id(cand)], dict)
m_attrs[f'choice_{i}'] = modules_arg[id(cand)]
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
choices.append({'type': cand_type, 'parameters': modules_arg[id(cand)]})
m_attrs['choices'] = choices
m_attrs['label'] = module.label
return m_attrs
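# Resulting attrs look roughly like (hypothetical candidate class and label):
# {'choices': [{'type': '__torch__.my_pkg.ConvBlock', 'parameters': {'channels': 16}}, ...],
#  'label': 'layer1'}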
......@@ -425,17 +458,18 @@ def convert_module(script_module, module, module_name, ir_model):
# NOTE: nested LayerChoice is not supported yet, i.e., a candidate module
# that itself contains LayerChoice, InputChoice or ValueChoice
original_type_name = script_module.original_name
if original_type_name == Type.LayerChoice:
if original_type_name == OpTypeName.LayerChoice:
m_attrs = _handle_layerchoice(module)
return None, m_attrs
if original_type_name == Type.InputChoice:
if original_type_name == OpTypeName.InputChoice:
m_attrs = _handle_inputchoice(module)
return None, m_attrs
if original_type_name in Type.Placeholder:
if original_type_name == OpTypeName.Placeholder:
m_attrs = modules_arg[id(module)]
return None, m_attrs
if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
assert id(module) in modules_arg, f'{original_type_name} arguments are not recorded'
m_attrs = modules_arg[id(module)]
return None, m_attrs
......@@ -463,7 +497,15 @@ def convert_module(script_module, module, module_name, ir_model):
ir_graph._register()
return ir_graph, modules_arg[id(module)]
if id(module) not in modules_arg:
raise RuntimeError(f'{original_type_name} arguments are not recorded, \
you might have forgotten to decorate this class with @register_module()')
# TODO: if we parse this module, a graph (module class) will be created for it,
# so it is not necessary to record this module's arguments and
# return ir_graph, modules_arg[id(module)].
# That is, this part can be refactored to let users annotate which modules
# should not be parsed further.
return ir_graph, {}
def convert_to_graph(script_module, module, recorded_modules_arg):
"""
......
from enum import Enum
MODULE_EXCEPT_LIST = ['Sequential']
class Type:
"""Node Type class
class OpTypeName(str, Enum):
"""
Maps an op type to its type name string.
"""
Attr = 'Attr'
Constant = 'Constant'
......@@ -11,21 +14,22 @@ class Type:
InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
'aten::relu': 'Relu',
'aten::add': 'Add',
'aten::__getitem__': 'getitem',
'aten::append': 'Append',
'aten::len': 'Len',
'aten::slice': 'Slice',
'aten::cat': 'Cat',
'aten::size': 'Size',
'aten::view': 'View'
}
# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
'aten::relu': 'Relu',
'aten::add': 'Add',
'aten::__getitem__': 'getitem',
'aten::append': 'Append',
'aten::len': 'Len',
'aten::slice': 'Slice',
'aten::cat': 'Cat',
'aten::size': 'Size',
'aten::view': 'View',
'aten::eq': 'Eq',
'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
BasicOpsTF = {}
\ No newline at end of file
BasicOpsTF = {}
\ No newline at end of file
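# Illustrative note (naming details depend on build_full_name): an `aten::relu` node found
# in module "stem" becomes an IR node whose name combines "stem", BasicOpsPT['aten::relu']
# == 'Relu' and a global sequence number, while node.kind() ('aten::relu') is kept as the
# IR node's type for later code generation.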
......@@ -36,13 +36,6 @@ def get_and_register_default_listener(engine: AbstractExecutionEngine) -> Defaul
engine.register_graph_listener(_default_listener)
return _default_listener
def _get_search_space() -> 'Dict':
engine = get_execution_engine()
while True:
time.sleep(1)
if engine.get_search_space() is not None:
break
return engine.get_search_space()
def submit_models(*models: Model) -> None:
engine = get_execution_engine()
......
......@@ -50,10 +50,6 @@ class BaseExecutionEngine(AbstractExecutionEngine):
self._running_models: Dict[int, Model] = dict()
def get_search_space(self) -> 'JSON':
advisor = get_advisor()
return advisor.search_space
def submit_models(self, *models: Model) -> None:
for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model),
......@@ -106,7 +102,6 @@ class BaseExecutionEngine(AbstractExecutionEngine):
Initialize the model, hand it over to trainer.
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
# FIXME: update this part to dump code to a correct path!!!
with open('_generated_model.py', 'w') as f:
f.write(graph_data.model_script)
trainer_cls = utils.import_(graph_data.training_module)
......
import dataclasses
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Thread
from typing import Any, List, Optional
from ..experiment import Experiment, TrainingServiceConfig
from ..experiment import launcher, rest
from ..experiment.config.base import ConfigBase, PathLike
from ..experiment.config import util
from .utils import get_records
from .integration import RetiariiAdvisor
from .converter.graph_gen import convert_to_graph
from .mutator import LayerChoiceMutator, InputChoiceMutator
_logger = logging.getLogger(__name__)
@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
experiment_name: Optional[str] = None
search_space: Any = '' # TODO: remove
trial_command: str = 'python3 -m nni.retiarii.trial_entry'
trial_code_directory: PathLike = '.'
trial_concurrency: int
trial_gpu_number: int = 0
max_experiment_duration: Optional[str] = None
max_trial_number: Optional[int] = None
nni_manager_ip: Optional[str] = None
debug: bool = False
log_level: Optional[str] = None
experiment_working_directory: Optional[PathLike] = None
# remove configuration of tuner/assessor/advisor
training_service: TrainingServiceConfig
def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if training_service_platform is not None:
assert 'training_service' not in kwargs
self.training_service = util.training_service_config_factory(training_service_platform)
def validate(self, initialized_tuner: bool = False) -> None:
super().validate()
@property
def _canonical_rules(self):
return _canonical_rules
@property
def _validation_rules(self):
return _validation_rules
_canonical_rules = {
'trial_code_directory': util.canonical_path,
'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
'experiment_working_directory': util.canonical_path
}
_validation_rules = {
'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not a directory'),
'trial_concurrency': lambda value: value > 0,
'trial_gpu_number': lambda value: value >= 0,
'max_experiment_duration': lambda value: util.parse_time(value) > 0,
'max_trial_number': lambda value: value > 0,
'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}
class RetiariiExperiment(Experiment):
def __init__(self, base_model: 'nn.Module', trainer: 'BaseTrainer',
applied_mutators: List['Mutator'], strategy: 'BaseStrategy'):
self.config: RetiariiExeConfig = None
self.port: Optional[int] = None
self.base_model = base_model
self.trainer = trainer
self.applied_mutators = applied_mutators
self.strategy = strategy
self.recorded_module_args = get_records()
self._dispatcher = RetiariiAdvisor()
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
def _process_inline_mutation(self, base_model):
"""
Generate mutators for inline LayerChoice/InputChoice nodes; the generated mutators are order independent.
"""
lc_nodes = base_model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.nn.LayerChoice')
ic_nodes = base_model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.nn.InputChoice')
if not lc_nodes and not ic_nodes:
return None
applied_mutators = []
for node in lc_nodes:
mutator = LayerChoiceMutator(node.name, node.operation.parameters['choices'])
applied_mutators.append(mutator)
for node in ic_nodes:
mutator = InputChoiceMutator(node.name, node.operation.parameters['n_chosen'])
applied_mutators.append(mutator)
return applied_mutators
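# For example, a base model with one LayerChoice node and one InputChoice node yields
# [LayerChoiceMutator(...), InputChoiceMutator(...)]; if the model contains neither,
# None is returned and the user-provided mutators are used instead.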
def _start_strategy(self):
import torch
try:
script_module = torch.jit.script(self.base_model)
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
assert id(self.trainer) in self.recorded_module_args
trainer_config = self.recorded_module_args[id(self.trainer)]
base_model.apply_trainer(trainer_config['modulename'], trainer_config['args'])
# handle inline mutations
mutators = self._process_inline_mutation(base_model)
if mutators is not None and self.applied_mutators:
raise RuntimeError('Mixed usage of LayerChoice/InputChoice and explicit mutators is not supported yet, \
do not pass mutators when you use LayerChoice/InputChoice')
if mutators is not None:
self.applied_mutators = mutators
_logger.info('Starting strategy...')
Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start()
_logger.info('Strategy started!')
def start(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> None:
"""
Start the experiment in the background.
This method raises an exception on failure.
If it returns, the experiment has been started successfully.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
"""
# FIXME:
if debug:
logging.getLogger('nni').setLevel(logging.DEBUG)
self._proc, self._pipe = launcher.start_experiment(config, port, debug)
assert self._proc is not None
assert self._pipe is not None
self.port = port # port will be None if start up failed
# dispatcher must be created after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
Thread(target=self._dispatcher.run).start()
self._start_strategy()
# TODO: register experiment management metadata
def stop(self) -> None:
"""
Stop background experiment.
"""
self._proc.kill()
self._pipe.close()
self.port = None
self._proc = None
self._pipe = None
def run(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> str:
"""
Run the experiment.
This function will block until the experiment finishes or errors.
"""
self.config = config
self.start(config, port, debug)
try:
while True:
time.sleep(10)
status = self.get_status()
# TODO: double check the status
if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
return status
finally:
self.stop()
def get_status(self) -> str:
if self.port is None:
raise RuntimeError('Experiment is not running')
resp = rest.get(self.port, '/check-status')
return resp['status']
\ No newline at end of file
......@@ -169,10 +169,29 @@ class Model:
matched_nodes.extend(nodes)
return matched_nodes
def get_by_name(self, name: str) -> Union['Graph', 'Node']:
def get_nodes_by_type(self, type_name: str) -> List['Node']:
"""
Find the graph or node that have the given name space name.
Traverse all the nodes to find the matched node(s) with the given type.
"""
matched_nodes = []
for graph in self.graphs.values():
nodes = graph.get_nodes_by_type(type_name)
matched_nodes.extend(nodes)
return matched_nodes
def get_node_by_name(self, node_name: str) -> 'Node':
"""
Traverse all the nodes to find the matched node with the given name.
"""
matched_nodes = []
for graph in self.graphs.values():
nodes = graph.get_nodes_by_name(node_name)
matched_nodes.extend(nodes)
assert len(matched_nodes) <= 1
if matched_nodes:
return matched_nodes[0]
else:
return None
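# Usage sketch (hypothetical node name):
# lc_nodes = model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.nn.LayerChoice')
# node = model.get_node_by_name('model__conv1')  # returns None when no node matches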
class ModelStatus(Enum):
......@@ -326,6 +345,9 @@ class Graph:
def get_nodes_by_label(self, label: str) -> List['Node']:
return [node for node in self.hidden_nodes if node.label == label]
def get_nodes_by_name(self, name: str) -> List['Node']:
return [node for node in self.hidden_nodes if node.name == name]
def topo_sort(self) -> List['Node']:
node_to_fanin = {}
curr_nodes = []
......
......@@ -101,4 +101,32 @@ class _RecorderSampler(Sampler):
def choice(self, candidates: List[Choice], *args) -> Choice:
self.recorded_candidates.append(candidates)
return candidates[0]
\ No newline at end of file
return candidates[0]
# the following is for inline mutation
class LayerChoiceMutator(Mutator):
def __init__(self, node_name: str, candidates: List):
super().__init__()
self.node_name = node_name
self.candidates = candidates
def mutate(self, model):
target = model.get_node_by_name(self.node_name)
indexes = [i for i in range(len(self.candidates))]
chosen_index = self.choice(indexes)
chosen_cand = self.candidates[chosen_index]
target.update_operation(chosen_cand['type'], chosen_cand['parameters'])
class InputChoiceMutator(Mutator):
def __init__(self, node_name: str, n_chosen: int):
super().__init__()
self.node_name = node_name
self.n_chosen = n_chosen
def mutate(self, model):
target = model.get_node_by_name(self.node_name)
candidates = [i for i in range(self.n_chosen)]
chosen = self.choice(candidates)
target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs',
{'chosen': chosen})
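# Minimal usage sketch (hypothetical node name and candidate spec); bind_sampler/apply
# come from the Mutator base class, as used by TPEStrategy:
# mutator = LayerChoiceMutator('model__layerchoice1',
#                              [{'type': '__torch__.my_pkg.ConvBlock', 'parameters': {'channels': 16}}])
# mutator.bind_sampler(sampler)
# new_model = mutator.apply(base_model)  # replaces the node's operation with the chosen candidate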
......@@ -4,47 +4,58 @@ import torch
import torch.nn as nn
from typing import (Any, Tuple, List, Optional)
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
_records = None
def enable_record_args():
global _records
_records = {}
_logger.info('args recording enabled')
from ...utils import add_record
def disable_record_args():
global _records
_records = None
_logger.info('args recording disabled')
def get_records():
global _records
return _records
_logger = logging.getLogger(__name__)
def add_record(name, value):
global _records
if _records is not None:
assert name not in _records, '{} already in _records'.format(name)
_records[name] = value
__all__ = [
'LayerChoice', 'InputChoice', 'Placeholder',
'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
'Tanhshrink', 'RReLU', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
'LSTMCell', 'GRUCell', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
'Flatten', 'Hardsigmoid', 'Hardswish'
]
class LayerChoice(nn.Module):
def __init__(self, candidate_ops: List, label: str = None):
def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
super(LayerChoice, self).__init__()
self.candidate_ops = candidate_ops
self.label = label
self.candidate_ops = op_candidates
self.label = key
if reduction or return_mask:
_logger.warning('input arguments `reduction` and `return_mask` are deprecated!')
def forward(self, x):
return x
class InputChoice(nn.Module):
def __init__(self, n_chosen: int = 1, reduction: str = 'sum', label: str = None):
def __init__(self, n_candidates=None, choose_from=None, n_chosen=1,
reduction="sum", return_mask=False, key=None):
super(InputChoice, self).__init__()
self.n_chosen = n_chosen
self.reduction = reduction
self.label = label
self.label = key
if n_candidates or choose_from or return_mask:
_logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')
def forward(self, candidate_inputs: List['Tensor']) -> 'Tensor':
# fake return
......@@ -62,9 +73,7 @@ class ValueChoice:
class Placeholder(nn.Module):
def __init__(self, label, related_info):
global _records
if _records is not None:
_records[id(self)] = related_info
add_record(id(self), related_info)
self.label = label
self.related_info = related_info
super(Placeholder, self).__init__()
......@@ -72,34 +81,29 @@ class Placeholder(nn.Module):
def forward(self, x):
return x
class ChosenInputs(nn.Module):
def __init__(self, chosen: int):
super().__init__()
self.chosen = chosen
def forward(self, candidate_inputs):
# TODO: support multiple chosen inputs
return candidate_inputs[self.chosen]
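# How this connects to InputChoice: InputChoiceMutator rewrites a chosen InputChoice node
# into '__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs' with {'chosen': k}, so the
# generated model effectively instantiates ChosenInputs(chosen=k) and forward() then
# returns candidate_inputs[k].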
# the following are pytorch modules
class Module(nn.Module):
def __init__(self, *args, **kwargs):
# TODO: users have to pass init's arguments to super init's arguments
global _records
if _records is not None:
assert not kwargs
argname_list = list(inspect.signature(self.__class__).parameters.keys())
assert len(argname_list) == len(args), 'Error: {} not put input arguments in its super().__init__ function'.format(self.__class__)
full_args = {}
for i, arg_value in enumerate(args):
full_args[argname_list[i]] = args[i]
_records[id(self)] = full_args
#print('my module: ', id(self), args, kwargs)
def __init__(self):
super(Module, self).__init__()
class Sequential(nn.Sequential):
def __init__(self, *args):
global _records
if _records is not None:
_records[id(self)] = {} # no args need to be recorded
add_record(id(self), {})
super(Sequential, self).__init__(*args)
class ModuleList(nn.ModuleList):
def __init__(self, *args):
global _records
if _records is not None:
_records[id(self)] = {} # no args need to be recorded
add_record(id(self), {})
super(ModuleList, self).__init__(*args)
def wrap_module(original_class):
......@@ -108,24 +112,132 @@ def wrap_module(original_class):
# Make copy of original __init__, so we can call it without recursion
def __init__(self, *args, **kws):
global _records
if _records is not None:
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = args[i]
_records[id(self)] = full_args
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = args[i]
add_record(id(self), full_args)
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
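# Illustrative effect of wrapping (argument names come from torch.nn.Conv2d's signature):
# conv = Conv2d(3, 16, kernel_size=3)
# # records {'in_channels': 3, 'out_channels': 16, 'kernel_size': 3} under id(conv)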
# TODO: support different versions of pytorch
Identity = wrap_module(nn.Identity)
Linear = wrap_module(nn.Linear)
Conv1d = wrap_module(nn.Conv1d)
Conv2d = wrap_module(nn.Conv2d)
BatchNorm2d = wrap_module(nn.BatchNorm2d)
Conv3d = wrap_module(nn.Conv3d)
ConvTranspose1d = wrap_module(nn.ConvTranspose1d)
ConvTranspose2d = wrap_module(nn.ConvTranspose2d)
ConvTranspose3d = wrap_module(nn.ConvTranspose3d)
Threshold = wrap_module(nn.Threshold)
ReLU = wrap_module(nn.ReLU)
Dropout = wrap_module(nn.Dropout)
Linear = wrap_module(nn.Linear)
MaxPool2d = wrap_module(nn.MaxPool2d)
Hardtanh = wrap_module(nn.Hardtanh)
ReLU6 = wrap_module(nn.ReLU6)
Sigmoid = wrap_module(nn.Sigmoid)
Tanh = wrap_module(nn.Tanh)
Softmax = wrap_module(nn.Softmax)
Softmax2d = wrap_module(nn.Softmax2d)
LogSoftmax = wrap_module(nn.LogSoftmax)
ELU = wrap_module(nn.ELU)
SELU = wrap_module(nn.SELU)
CELU = wrap_module(nn.CELU)
GLU = wrap_module(nn.GLU)
GELU = wrap_module(nn.GELU)
Hardshrink = wrap_module(nn.Hardshrink)
LeakyReLU = wrap_module(nn.LeakyReLU)
LogSigmoid = wrap_module(nn.LogSigmoid)
Softplus = wrap_module(nn.Softplus)
Softshrink = wrap_module(nn.Softshrink)
MultiheadAttention = wrap_module(nn.MultiheadAttention)
PReLU = wrap_module(nn.PReLU)
Softsign = wrap_module(nn.Softsign)
Softmin = wrap_module(nn.Softmin)
Tanhshrink = wrap_module(nn.Tanhshrink)
RReLU = wrap_module(nn.RReLU)
AvgPool1d = wrap_module(nn.AvgPool1d)
AvgPool2d = wrap_module(nn.AvgPool2d)
Identity = wrap_module(nn.Identity)
AvgPool3d = wrap_module(nn.AvgPool3d)
MaxPool1d = wrap_module(nn.MaxPool1d)
MaxPool2d = wrap_module(nn.MaxPool2d)
MaxPool3d = wrap_module(nn.MaxPool3d)
MaxUnpool1d = wrap_module(nn.MaxUnpool1d)
MaxUnpool2d = wrap_module(nn.MaxUnpool2d)
MaxUnpool3d = wrap_module(nn.MaxUnpool3d)
FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d)
LPPool1d = wrap_module(nn.LPPool1d)
LPPool2d = wrap_module(nn.LPPool2d)
LocalResponseNorm = wrap_module(nn.LocalResponseNorm)
BatchNorm1d = wrap_module(nn.BatchNorm1d)
BatchNorm2d = wrap_module(nn.BatchNorm2d)
BatchNorm3d = wrap_module(nn.BatchNorm3d)
InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
LayerNorm = wrap_module(nn.LayerNorm)
GroupNorm = wrap_module(nn.GroupNorm)
SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
Dropout = wrap_module(nn.Dropout)
Dropout2d = wrap_module(nn.Dropout2d)
Dropout3d = wrap_module(nn.Dropout3d)
AlphaDropout = wrap_module(nn.AlphaDropout)
FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
Embedding = wrap_module(nn.Embedding)
EmbeddingBag = wrap_module(nn.EmbeddingBag)
RNNBase = wrap_module(nn.RNNBase)
RNN = wrap_module(nn.RNN)
LSTM = wrap_module(nn.LSTM)
GRU = wrap_module(nn.GRU)
RNNCellBase = wrap_module(nn.RNNCellBase)
RNNCell = wrap_module(nn.RNNCell)
LSTMCell = wrap_module(nn.LSTMCell)
GRUCell = wrap_module(nn.GRUCell)
PixelShuffle = wrap_module(nn.PixelShuffle)
Upsample = wrap_module(nn.Upsample)
UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
PairwiseDistance = wrap_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
ZeroPad2d = wrap_module(nn.ZeroPad2d)
ConstantPad1d = wrap_module(nn.ConstantPad1d)
ConstantPad2d = wrap_module(nn.ConstantPad2d)
ConstantPad3d = wrap_module(nn.ConstantPad3d)
Bilinear = wrap_module(nn.Bilinear)
CosineSimilarity = wrap_module(nn.CosineSimilarity)
Unfold = wrap_module(nn.Unfold)
Fold = wrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = wrap_module(nn.TransformerEncoder)
TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer)
#LazyLinear = wrap_module(nn.LazyLinear)
#LazyConv1d = wrap_module(nn.LazyConv1d)
#LazyConv2d = wrap_module(nn.LazyConv2d)
#LazyConv3d = wrap_module(nn.LazyConv3d)
#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
Flatten = wrap_module(nn.Flatten)
#Unflatten = wrap_module(nn.Unflatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
Hardswish = wrap_module(nn.Hardswish)
#SiLU = wrap_module(nn.SiLU)
#TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
#ChannelShuffle = wrap_module(nn.ChannelShuffle)
......@@ -105,7 +105,7 @@ class PyTorchOperation(Operation):
return None
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
from .converter.op_types import Type
from .converter.op_types import OpTypeName
if self._to_class_name() is not None:
return f'{output} = self.{field}({", ".join(inputs)})'
elif self.type.startswith('Function.'):
......@@ -133,7 +133,7 @@ class PyTorchOperation(Operation):
elif self.type == 'aten::add':
assert len(inputs) == 2
return f'{output} = {inputs[0]} + {inputs[1]}'
elif self.type == Type.MergedSlice:
elif self.type == OpTypeName.MergedSlice:
assert (len(inputs) - 1) % 4 == 0
slices = []
dim = int((len(inputs) - 1) / 4)
......
......@@ -4,5 +4,5 @@ from typing import List
class BaseStrategy(abc.ABC):
@abc.abstractmethod
def run(self, base_model: 'Model', applied_mutators: List['Mutator'], trainer: 'BaseTrainer') -> None:
def run(self, base_model: 'Model', applied_mutators: List['Mutator']) -> None:
pass
......@@ -42,7 +42,7 @@ class TPEStrategy(BaseStrategy):
self.tpe_sampler = TPESampler()
self.model_id = 0
def run(self, base_model, applied_mutators, trainer):
def run(self, base_model, applied_mutators):
sample_space = []
new_model = base_model
for mutator in applied_mutators:
......@@ -61,9 +61,6 @@ class TPEStrategy(BaseStrategy):
_logger.info('mutate model...')
mutator.bind_sampler(self.tpe_sampler)
model = mutator.apply(model)
# get and apply training approach
_logger.info('apply training approach...')
model.apply_trainer(trainer['modulename'], trainer['args'])
# run models
submit_models(model)
wait_models(model)
......
import abc
import inspect
from ..nn.pytorch import add_record
from typing import *
......@@ -19,23 +18,6 @@ class BaseTrainer(abc.ABC):
Trainer has a ``fit`` function with no return value. Intermediate results and final results should be
directly sent via ``nni.report_intermediate_result()`` and ``nni.report_final_result()`` functions.
"""
def __init__(self, *args, **kwargs):
module = self.__class__.__module__
if module is None or module == str.__class__.__module__:
full_class_name = self.__class__.__name__
else:
full_class_name = module + '.' + self.__class__.__name__
assert not kwargs
argname_list = list(inspect.signature(self.__class__).parameters.keys())
assert len(argname_list) == len(args), 'Error: {} not put input arguments in its super().__init__ function'.format(self.__class__)
full_args = {}
for i, arg_value in enumerate(args):
if argname_list[i] == 'model':
assert i == 0
continue
full_args[argname_list[i]] = args[i]
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
@abc.abstractmethod
def fit(self) -> None:
......
......@@ -10,6 +10,7 @@ from torchvision import datasets, transforms
import nni
from ..interface import BaseTrainer
from ...utils import register_trainer
def get_default_transform(dataset: str) -> Any:
......@@ -41,7 +42,7 @@ def get_default_transform(dataset: str) -> Any:
# unsupported dataset, return None
return None
@register_trainer()
class PyTorchImageClassificationTrainer(BaseTrainer):
"""
Image classification trainer for PyTorch.
......@@ -78,9 +79,6 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
only the key ``max_epochs`` is useful.
"""
super(PyTorchImageClassificationTrainer, self).__init__(model,
dataset_cls, dataset_kwargs, dataloader_kwargs,
optimizer_cls, optimizer_kwargs, trainer_kwargs)
self._use_cuda = torch.cuda.is_available()
self.model = model
if self._use_cuda:
......
import traceback
from .nn.pytorch import enable_record_args, get_records, disable_record_args
import inspect
def import_(target: str, allow_none: bool = False) -> 'Any':
if target is None:
......@@ -8,17 +7,79 @@ def import_(target: str, allow_none: bool = False) -> 'Any':
module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier)
class TraceClassArguments:
def __init__(self):
self.recorded_arguments = None
def __enter__(self):
enable_record_args()
return self
def __exit__(self, exc_type, exc_value, tb):
if exc_type is not None:
traceback.print_exception(exc_type, exc_value, tb)
# return False # uncomment to pass exception through
self.recorded_arguments = get_records()
disable_record_args()
_records = {}
def get_records():
global _records
return _records
def add_record(key, value):
"""
"""
global _records
if _records is not None:
assert key not in _records, '{} already in _records'.format(key)
_records[key] = value
def _register_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = args[i]
add_record(id(self), full_args)
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
def register_module():
"""
Decorator that registers a module class so that its constructor arguments are recorded on instantiation.
"""
# use it as a decorator: @register_module()
def _register(cls):
m = _register_module(
original_class=cls)
return m
return _register
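# Minimal usage sketch (MyBlock is a hypothetical user module subclassing
# nni.retiarii.nn.pytorch's Module); the recorded arguments let the converter
# re-create the module from the graph IR:
#
# @register_module()
# class MyBlock(nn.Module):
#     def __init__(self, channels, stride=1):
#         super().__init__()
#
# block = MyBlock(16, stride=2)
# get_records()[id(block)]  # -> {'stride': 2, 'channels': 16}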
def _register_trainer(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
full_class_name = original_class.__module__ + '.' + original_class.__name__
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
# TODO: support both pytorch and tensorflow
from .nn.pytorch import Module
if isinstance(args[i], Module):
# ignore the base model object
continue
full_args[argname_list[i]] = args[i]
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
def register_trainer():
def _register(cls):
m = _register_trainer(
original_class=cls)
return m
return _register
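# register_trainer() records trainers the same way, but skips the base model argument and
# stores the fully qualified class name, e.g. (hypothetical trainer):
# add_record(id(trainer), {'modulename': 'my_pkg.MyTrainer', 'args': {'max_epochs': 1}})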
\ No newline at end of file
from collections import OrderedDict
from typing import (List, Optional)
import torch
import torch.nn as torch_nn
#sys.path.append(str(Path(__file__).resolve().parents[2]))
import ops
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module
class AuxiliaryHead(nn.Module):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """
def __init__(self, input_size, C, n_classes):
""" assuming input size 7x7 or 8x8 """
assert input_size in [7, 8]
super().__init__()
self.net = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False), # 2x2 out
nn.Conv2d(C, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, kernel_size=2, bias=False), # 1x1 out
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.linear = nn.Linear(768, n_classes)
def forward(self, x):
out = self.net(x)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)
return logits
@register_module()
class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__()
self.ops = nn.ModuleList()
choice_keys = []
for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
nn.LayerChoice([
ops.PoolBN('max', channels, 3, stride, 1, affine=False),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
ops.SepConv(channels, channels, 3, stride, 1, affine=False),
ops.SepConv(channels, channels, 5, stride, 2, affine=False),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
]))
self.drop_path = ops.DropPath()
self.input_switch = nn.InputChoice(n_chosen=2)
def forward(self, prev_nodes: List['Tensor']) -> 'Tensor':
#assert self.ops.__len__() == len(prev_nodes)
#out = [op(node) for op, node in zip(self.ops, prev_nodes)]
out = []
for i, op in enumerate(self.ops):
out.append(op(prev_nodes[i]))
#out = [self.drop_path(o) if o is not None else None for o in out]
return self.input_switch(out)
@register_module()
class Cell(nn.Module):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
super().__init__()
self.reduction = reduction
self.n_nodes = n_nodes
# If previous cell is reduction cell, current input size does not match with
# output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
if reduction_p:
self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
else:
self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)
# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(2, self.n_nodes + 2):
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
depth, channels, 2 if reduction else 0))
def forward(self, s0, s1):
# s0, s1 are the outputs of previous previous cell and previous cell, respectively.
tensors = [self.preproc0(s0), self.preproc1(s1)]
new_tensors = []
for node in self.mutable_ops:
tmp = tensors + new_tensors
cur_tensor = node(tmp)
new_tensors.append(cur_tensor)
output = torch.cat(new_tensors, dim=1)
return output
@register_module()
class CNN(nn.Module):
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
stem_multiplier=3, auxiliary=False):
super().__init__()
self.in_channels = in_channels
self.channels = channels
self.n_classes = n_classes
self.n_layers = n_layers
self.aux_pos = 2 * n_layers // 3 if auxiliary else -1
c_cur = stem_multiplier * self.channels
self.stem = nn.Sequential(
nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
nn.BatchNorm2d(c_cur)
)
# for the first cell, stem is used for both s0 and s1
# [!] channels_pp and channels_p are output channel sizes, but c_cur is the input channel size.
channels_pp, channels_p, c_cur = c_cur, c_cur, channels
self.cells = nn.ModuleList()
reduction_p, reduction = False, False
for i in range(n_layers):
reduction_p, reduction = reduction, False
# Reduce feature map size and double the channels at the 1/3 and 2/3 layers.
if i in [n_layers // 3, 2 * n_layers // 3]:
c_cur *= 2
reduction = True
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
self.cells.append(cell)
c_cur_out = c_cur * n_nodes
channels_pp, channels_p = channels_p, c_cur_out
#if i == self.aux_pos:
# self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)
self.gap = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(channels_p, n_classes)
def forward(self, x):
s0 = s1 = self.stem(x)
#aux_logits = None
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1)
#if i == self.aux_pos and self.training:
# aux_logits = self.aux_head(s1)
out = self.gap(s1)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)
#if aux_logits is not None:
# return logits, aux_logits
return logits
def drop_path_prob(self, p):
for module in self.modules():
if isinstance(module, ops.DropPath):
module.p = p
if __name__ == '__main__':
base_model = CNN(32, 3, 16, 10, 8)
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module
@register_module()
class DropPath(nn.Module):
def __init__(self, p=0.):
"""
Drop path with probability.
Parameters
----------
p : float
Probability of a path being zeroed.
"""
super(DropPath, self).__init__()
self.p = p
def forward(self, x):
if self.training and self.p > 0.:
keep_prob = 1. - self.p
# per data point mask
mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
return x / keep_prob * mask
return x
@register_module()
class PoolBN(nn.Module):
"""
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
super(PoolBN, self).__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError()
self.bn = nn.BatchNorm2d(C, affine=affine)
def forward(self, x):
out = self.pool(x)
out = self.bn(out)
return out
@register_module()
class StdConv(nn.Module):
"""
Standard conv: ReLU - Conv - BN
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(StdConv, self).__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
@register_module()
class FacConv(nn.Module):
"""
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super(FacConv, self).__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
@register_module()
class DilConv(nn.Module):
"""
(Dilated) depthwise separable conv.
ReLU - (Dilated) depthwise separable - Pointwise - BN.
If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super(DilConv, self).__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
bias=False),
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
@register_module()
class SepConv(nn.Module):
"""
Depthwise separable conv.
DilConv(dilation=1) * 2.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__()
self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
)
def forward(self, x):
return self.net(x)
@register_module()
class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise (stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super(FactorizedReduce, self).__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
x = self.relu(x)
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
import json
import os
import sys
import torch
from pathlib import Path
from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategies import TPEStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer
from darts_model import CNN
if __name__ == '__main__':
base_model = CNN(32, 3, 16, 10, 8)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
dataset_kwargs={"root": "data/cifar10", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
simple_strategy = TPEStrategy()
exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'darts_search'
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 10
exp_config.trial_gpu_number = 1
exp_config.training_service.use_active_gpu = True
exp_config.training_service.gpu_indices = [1, 2]
exp.run(exp_config, 8081, debug=True)