"...composable_kernel_rocm.git" did not exist on "7d45045c0036ae4dcce4cf5968355991956c7091"
Unverified Commit ae50ed14 authored by Yuge Zhang, committed by GitHub

Refactor wrap module as "blackbox_module" (#3238)

parent 15da19d3
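As context for the diff that follows: the old `wrap_module` / `register_module` helpers are renamed to `blackbox` / `blackbox_module`, and trainer registration moves to a plain `@register_trainer` decorator. A minimal sketch of the new decorator usage (the class `MyOp` is hypothetical; decorating it records the constructor arguments so the graph converter treats the module as an opaque node instead of parsing its internals):

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module

@blackbox_module              # record __init__ args; skip graph parsing for this module
class MyOp(nn.Module):        # hypothetical user-defined op
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return self.conv(x)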
...@@ -2,4 +2,4 @@ from .operation import Operation ...@@ -2,4 +2,4 @@ from .operation import Operation
from .graph import * from .graph import *
from .execution import * from .execution import *
from .mutator import * from .mutator import *
from .utils import register_module from .utils import blackbox, blackbox_module, register_trainer
\ No newline at end of file
...@@ -19,10 +19,10 @@ def model_to_pytorch_script(model: Model, placement=None) -> str: ...@@ -19,10 +19,10 @@ def model_to_pytorch_script(model: Model, placement=None) -> str:
def _sorted_incoming_edges(node: Node) -> List[Edge]: def _sorted_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node] edges = [edge for edge in node.graph.edges if edge.tail is node]
_logger.info('sorted_incoming_edges: %s', str(edges)) _logger.debug('sorted_incoming_edges: %s', str(edges))
if not edges: if not edges:
return [] return []
_logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges])) _logger.debug('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
if all(edge.tail_slot is None for edge in edges): if all(edge.tail_slot is None for edge in edges):
return edges return edges
if all(isinstance(edge.tail_slot, int) for edge in edges): if all(isinstance(edge.tail_slot, int) for edge in edges):
......
...@@ -6,17 +6,20 @@ import torch ...@@ -6,17 +6,20 @@ import torch
from ..graph import Graph, Model, Node from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell from ..operation import Cell
from ..utils import get_records
from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
from .utils import _convert_name, build_full_name from .utils import _convert_name, build_full_name
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
global_seq = 0
global_graph_id = 0
modules_arg = None
class GraphConverter:
def __init__(self):
self.global_seq = 0
self.global_graph_id = 0
self.modules_arg = get_records()
def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False): def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
""" """
Parameters Parameters
---------- ----------
...@@ -76,29 +79,24 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ...@@ -76,29 +79,24 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap,
new_node_input_idx += 1 new_node_input_idx += 1
def create_prim_constant_node(self, ir_graph, node, module_name):
def create_prim_constant_node(ir_graph, node, module_name):
global global_seq
attrs = {} attrs = {}
if node.outputsAt(0).toIValue() is not None: if node.outputsAt(0).toIValue() is not None:
attrs = {'value': node.outputsAt(0).toIValue()} attrs = {'value': node.outputsAt(0).toIValue()}
global_seq += 1 self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, global_seq), new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
node.kind(), attrs) node.kind(), attrs)
return new_node return new_node
def handle_prim_attr_node(self, node):
def handle_prim_attr_node(node):
assert node.hasAttribute('name') assert node.hasAttribute('name')
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()} attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
return node.kind(), attrs return node.kind(), attrs
def _remove_mangle(self, module_type_str):
def _remove_mangle(module_type_str):
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str) return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
def remove_unconnected_nodes(self, ir_graph, targeted_type=None):
def remove_unconnected_nodes(ir_graph, targeted_type=None):
""" """
Parameters Parameters
---------- ----------
...@@ -126,8 +124,7 @@ def remove_unconnected_nodes(ir_graph, targeted_type=None): ...@@ -126,8 +124,7 @@ def remove_unconnected_nodes(ir_graph, targeted_type=None):
for hidden_node in to_removes: for hidden_node in to_removes:
hidden_node.remove() hidden_node.remove()
def handle_graph_nodes(self, script_module, sm_graph, module, module_name, ir_model, ir_graph):
def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
""" """
Convert torch script node to our node ir, and build our graph ir Convert torch script node to our node ir, and build our graph ir
...@@ -234,13 +231,12 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i ...@@ -234,13 +231,12 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
Node Node
the created node ir the created node ir
""" """
global global_seq
if node.kind() == 'prim::CallMethod': if node.kind() == 'prim::CallMethod':
# get and handle the first input, which should be an nn.Module # get and handle the first input, which should be an nn.Module
assert node.hasAttribute('name') assert node.hasAttribute('name')
if node.s('name') == 'forward': if node.s('name') == 'forward':
# node.inputsAt(0).type() is <class 'torch._C.ClassType'> # node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str = _remove_mangle(node.inputsAt(0).type().str()) submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
submodule = node.inputsAt(0).node() submodule = node.inputsAt(0).node()
assert submodule.kind() == 'prim::GetAttr' assert submodule.kind() == 'prim::GetAttr'
assert submodule.hasAttribute('name') assert submodule.hasAttribute('name')
...@@ -258,7 +254,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i ...@@ -258,7 +254,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
submodule_full_name = build_full_name(module_name, submodule_name) submodule_full_name = build_full_name(module_name, submodule_name)
submodule_obj = getattr(module, submodule_name) submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name], subgraph, sub_m_attrs = self.convert_module(script_module._modules[submodule_name],
submodule_obj, submodule_obj,
submodule_full_name, ir_model) submodule_full_name, ir_model)
else: else:
...@@ -276,7 +272,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i ...@@ -276,7 +272,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name]) submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
predecessor_obj = getattr(module, predecessor_name) predecessor_obj = getattr(module, predecessor_name)
submodule_obj = getattr(predecessor_obj, submodule_name) submodule_obj = getattr(predecessor_obj, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[predecessor_name]._modules[submodule_name], subgraph, sub_m_attrs = self.convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
submodule_obj, submodule_full_name, ir_model) submodule_obj, submodule_full_name, ir_model)
else: else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str())) raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
...@@ -296,45 +292,45 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i ...@@ -296,45 +292,45 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
subcell = ir_graph.add_node(submodule_full_name, new_cell) subcell = ir_graph.add_node(submodule_full_name, new_cell)
node_index[node] = subcell node_index[node] = subcell
# connect the cell into graph # connect the cell into graph
_add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True) self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
else: else:
raise RuntimeError('unsupported CallMethod {}'.format(node.s('name'))) raise RuntimeError('unsupported CallMethod {}'.format(node.s('name')))
elif node.kind() == 'prim::CallFunction': elif node.kind() == 'prim::CallFunction':
func_type_str = _remove_mangle(node.inputsAt(0).type().str()) func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
func = node.inputsAt(0).node() func = node.inputsAt(0).node()
assert func.kind() == 'prim::Constant' assert func.kind() == 'prim::Constant'
assert func.hasAttribute('name') assert func.hasAttribute('name')
func_name = func.s('name') func_name = func.s('name')
# create node for func # create node for func
global_seq += 1 self.global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, global_seq), func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
'{}.{}'.format(func_type_str, func_name)) '{}.{}'.format(func_type_str, func_name))
node_index[node] = func_node node_index[node] = func_node
_add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True) self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant': elif node.kind() == 'prim::Constant':
new_node = create_prim_constant_node(ir_graph, node, module_name) new_node = self.create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct': elif node.kind() == 'prim::ListConstruct':
global_seq += 1 self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, global_seq), node.kind()) new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, self.global_seq), node.kind())
node_index[node] = new_node node_index[node] = new_node
_add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap) self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append': elif node.kind() == 'aten::append':
global_seq += 1 self.global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind()) aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
node_index[node] = aten_node node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap) self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'): elif node.kind().startswith('aten::'):
# handle aten::XXX # handle aten::XXX
global_seq += 1 self.global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind()) aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
node_index[node] = aten_node node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap) self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
elif node.kind() == 'prim::GetAttr': elif node.kind() == 'prim::GetAttr':
node_type, attrs = handle_prim_attr_node(node) node_type, attrs = self.handle_prim_attr_node(node)
global_seq += 1 self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq), new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
node_type, attrs) node_type, attrs)
node_index[node] = new_node node_index[node] = new_node
elif node.kind() == 'prim::If': elif node.kind() == 'prim::If':
...@@ -354,8 +350,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i ...@@ -354,8 +350,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
return node_index return node_index
def merge_aten_slices(self, ir_graph):
def merge_aten_slices(ir_graph):
""" """
if there is aten::slice node, merge the consecutive ones together. if there is aten::slice node, merge the consecutive ones together.
```x[:, :, 1:, 1:]``` in python code will be converted into 4 nodes in torch script, ```x[:, :, 1:, 1:]``` in python code will be converted into 4 nodes in torch script,
...@@ -401,36 +396,31 @@ def merge_aten_slices(ir_graph): ...@@ -401,36 +396,31 @@ def merge_aten_slices(ir_graph):
edge.head = new_slice_node edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node) ir_graph.hidden_nodes.remove(node)
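As a quick illustration of why the slice-merging pass above exists (a sketch; the exact graph text varies across PyTorch versions), scripting a chained slice emits one aten::slice node per sliced dimension:

import torch

@torch.jit.script
def chained_slice(x):
    return x[:, :, 1:, 1:]

# The printed TorchScript graph contains a run of consecutive aten::slice nodes;
# merge_aten_slices collapses such a run into a single slice node in the IR.
print(chained_slice.graph)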
def refine_graph(self, ir_graph):
def refine_graph(ir_graph):
""" """
Do the following process to simplify graph: Do the following process to simplify graph:
1. remove unconnected constant node 1. remove unconnected constant node
2. remove unconnected getattr node 2. remove unconnected getattr node
""" """
# some constant is not used, for example, function name as prim::Constant # some constant is not used, for example, function name as prim::Constant
remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant') self.remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr') self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
merge_aten_slices(ir_graph) self.merge_aten_slices(ir_graph)
def _handle_layerchoice(module):
global modules_arg
def _handle_layerchoice(self, module):
m_attrs = {} m_attrs = {}
candidates = module.candidate_ops candidates = module.op_candidates
choices = [] choices = []
for cand in candidates: for cand in candidates:
assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand)) assert id(cand) in self.modules_arg, 'id not exist: {}'.format(id(cand))
assert isinstance(modules_arg[id(cand)], dict) assert isinstance(self.modules_arg[id(cand)], dict)
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__ cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
choices.append({'type': cand_type, 'parameters': modules_arg[id(cand)]}) choices.append({'type': cand_type, 'parameters': self.modules_arg[id(cand)]})
m_attrs[f'choices'] = choices m_attrs[f'choices'] = choices
m_attrs['label'] = module.label m_attrs['label'] = module.label
return m_attrs return m_attrs
def _handle_inputchoice(self, module):
def _handle_inputchoice(module):
m_attrs = {} m_attrs = {}
m_attrs['n_candidates'] = module.n_candidates m_attrs['n_candidates'] = module.n_candidates
m_attrs['n_chosen'] = module.n_chosen m_attrs['n_chosen'] = module.n_chosen
...@@ -438,8 +428,7 @@ def _handle_inputchoice(module): ...@@ -438,8 +428,7 @@ def _handle_inputchoice(module):
m_attrs['label'] = module.label m_attrs['label'] = module.label
return m_attrs return m_attrs
def convert_module(self, script_module, module, module_name, ir_model):
def convert_module(script_module, module, module_name, ir_model):
""" """
Convert a module to its graph ir (i.e., Graph) along with its input arguments Convert a module to its graph ir (i.e., Graph) along with its input arguments
...@@ -461,34 +450,36 @@ def convert_module(script_module, module, module_name, ir_model): ...@@ -461,34 +450,36 @@ def convert_module(script_module, module, module_name, ir_model):
dict dict
the input arguments of this module the input arguments of this module
""" """
global global_graph_id
global modules_arg
# NOTE: have not supported nested LayerChoice, i.e., a candidate module # NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice # also has LayerChoice or InputChoice or ValueChoice
original_type_name = script_module.original_name original_type_name = script_module.original_name
if original_type_name == OpTypeName.LayerChoice: m_attrs = None
m_attrs = _handle_layerchoice(module) if original_type_name in MODULE_EXCEPT_LIST:
return None, m_attrs pass # do nothing
if original_type_name == OpTypeName.InputChoice: elif original_type_name == OpTypeName.LayerChoice:
m_attrs = _handle_inputchoice(module) m_attrs = self._handle_layerchoice(module)
return None, m_attrs elif original_type_name == OpTypeName.InputChoice:
if original_type_name == OpTypeName.Placeholder: m_attrs = self._handle_inputchoice(module)
m_attrs = modules_arg[id(module)] elif original_type_name == OpTypeName.Placeholder:
return None, m_attrs m_attrs = self.modules_arg[id(module)]
if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST: elif original_type_name in torch.nn.__dict__:
# this is a basic module from pytorch, no need to parse its graph # this is a basic module from pytorch, no need to parse its graph
assert id(module) in modules_arg, f'{original_type_name} arguments are not recorded' assert id(module) in self.modules_arg, f'{original_type_name} arguments are not recorded'
m_attrs = modules_arg[id(module)] m_attrs = self.modules_arg[id(module)]
elif id(module) in self.modules_arg:
# this module is marked as blackbox, won't continue to parse
m_attrs = self.modules_arg[id(module)]
if m_attrs is not None:
return None, m_attrs return None, m_attrs
# handle TorchScript graph # handle TorchScript graph
sm_graph = script_module.graph sm_graph = script_module.graph
global_graph_id += 1 self.global_graph_id += 1
ir_graph = Graph(model=ir_model, graph_id=global_graph_id, name=module_name, _internal=True) ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)
# handle graph nodes # handle graph nodes
node_index = handle_graph_nodes(script_module, sm_graph, module, node_index = self.handle_graph_nodes(script_module, sm_graph, module,
module_name, ir_model, ir_graph) module_name, ir_model, ir_graph)
# handle graph outputs # handle graph outputs
...@@ -502,22 +493,14 @@ def convert_module(script_module, module, module_name, ir_model): ...@@ -502,22 +493,14 @@ def convert_module(script_module, module, module_name, ir_model):
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx), ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
tail=(ir_graph.output_node, None)) tail=(ir_graph.output_node, None))
refine_graph(ir_graph) self.refine_graph(ir_graph)
ir_graph._register() ir_graph._register()
if id(module) not in modules_arg:
raise RuntimeError(f'{original_type_name} arguments are not recorded, \
you might have forgotten to decorate this class with @register_module()')
# TODO: if we parse this module, it means we will create a graph (module class)
# for this module. Then it is not necessary to record this module's arguments
# return ir_graph, modules_arg[id(module)].
# That is, we can refactor this part, to allow users to annotate which module
# should not be parsed further.
return ir_graph, {} return ir_graph, {}
def convert_to_graph(script_module, module, recorded_modules_arg): def convert_to_graph(script_module, module):
""" """
Convert module to our graph ir, i.e., build a ```Model``` type Convert module to our graph ir, i.e., build a ```Model``` type
...@@ -527,18 +510,15 @@ def convert_to_graph(script_module, module, recorded_modules_arg): ...@@ -527,18 +510,15 @@ def convert_to_graph(script_module, module, recorded_modules_arg):
the script module obtained with torch.jit.script the script module obtained with torch.jit.script
module : nn.Module module : nn.Module
the targeted module instance the targeted module instance
recorded_modules_arg : dict
the recorded args of each module in the module
Returns Returns
-------
Model Model
the constructed IR model the constructed IR model
""" """
global modules_arg
modules_arg = recorded_modules_arg
model = Model(_internal=True) model = Model(_internal=True)
module_name = '_model' module_name = '_model'
convert_module(script_module, module, module_name, model) GraphConverter().convert_module(script_module, module, module_name, model)
return model return model
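With this refactor the converter no longer receives the recorded arguments explicitly; GraphConverter pulls them from get_records() itself. A rough usage sketch mirroring the experiment code below (`net` is a hypothetical base model built from nni.retiarii.nn.pytorch layers):

import torch

script_module = torch.jit.script(net)            # net: hypothetical nn.Module base model
model_ir = convert_to_graph(script_module, net)  # returns the graph IR (a Model instance)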
...@@ -29,6 +29,7 @@ _logger = logging.getLogger(__name__) ...@@ -29,6 +29,7 @@ _logger = logging.getLogger(__name__)
OneShotTrainers = (DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer) OneShotTrainers = (DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer)
@dataclass(init=False) @dataclass(init=False)
class RetiariiExeConfig(ConfigBase): class RetiariiExeConfig(ConfigBase):
experiment_name: Optional[str] = None experiment_name: Optional[str] = None
...@@ -125,14 +126,17 @@ class RetiariiExperiment(Experiment): ...@@ -125,14 +126,17 @@ class RetiariiExperiment(Experiment):
except Exception as e: except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:') _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e raise e
base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args) base_model_ir = convert_to_graph(script_module, self.base_model)
assert id(self.trainer) in self.recorded_module_args recorded_module_args = get_records()
trainer_config = self.recorded_module_args[id(self.trainer)] if id(self.trainer) not in recorded_module_args:
base_model.apply_trainer(trainer_config['modulename'], trainer_config['args']) raise KeyError('Your trainer is not found in registered classes. You might have forgotten to \
register your customized trainer with @register_trainer decorator.')
trainer_config = recorded_module_args[id(self.trainer)]
base_model_ir.apply_trainer(trainer_config['modulename'], trainer_config['args'])
# handle inline mutations # handle inline mutations
mutators = self._process_inline_mutation(base_model) mutators = self._process_inline_mutation(base_model_ir)
if mutators is not None and self.applied_mutators: if mutators is not None and self.applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \ raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \
do not use mutators when you use LayerChoice/InputChoice') do not use mutators when you use LayerChoice/InputChoice')
...@@ -140,7 +144,7 @@ class RetiariiExperiment(Experiment): ...@@ -140,7 +144,7 @@ class RetiariiExperiment(Experiment):
self.applied_mutators = mutators self.applied_mutators = mutators
_logger.info('Starting strategy...') _logger.info('Starting strategy...')
Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start() Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)).start()
_logger.info('Strategy started!') _logger.info('Strategy started!')
def start(self, port: int = 8080, debug: bool = False) -> None: def start(self, port: int = 8080, debug: bool = False) -> None:
......
import inspect
import logging import logging
from typing import Any, List from typing import Any, List
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...utils import add_record, version_larger_equal from ...utils import add_record, blackbox_module, uid, version_larger_equal
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -40,16 +39,13 @@ if version_larger_equal(torch.__version__, '1.6.0'): ...@@ -40,16 +39,13 @@ if version_larger_equal(torch.__version__, '1.6.0'):
if version_larger_equal(torch.__version__, '1.7.0'): if version_larger_equal(torch.__version__, '1.7.0'):
__all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss']) __all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'ChannelShuffle'
class LayerChoice(nn.Module): class LayerChoice(nn.Module):
def __init__(self, op_candidates, reduction=None, return_mask=False, key=None): def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
super(LayerChoice, self).__init__() super(LayerChoice, self).__init__()
self.candidate_ops = op_candidates self.op_candidates = op_candidates
self.label = key self.label = key if key is not None else f'layerchoice_{uid()}'
self.key = key # deprecated, for backward compatibility self.key = self.label # deprecated, for backward compatibility
for i, module in enumerate(op_candidates): # deprecated, for backward compatibility for i, module in enumerate(op_candidates): # deprecated, for backward compatibility
self.add_module(str(i), module) self.add_module(str(i), module)
if reduction or return_mask: if reduction or return_mask:
...@@ -66,8 +62,8 @@ class InputChoice(nn.Module): ...@@ -66,8 +62,8 @@ class InputChoice(nn.Module):
self.n_candidates = n_candidates self.n_candidates = n_candidates
self.n_chosen = n_chosen self.n_chosen = n_chosen
self.reduction = reduction self.reduction = reduction
self.label = key self.label = key if key is not None else f'inputchoice_{uid()}'
self.key = key # deprecated, for backward compatibility self.key = self.label # deprecated, for backward compatibility
if choose_from or return_mask: if choose_from or return_mask:
_logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!') _logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')
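The label handling above now falls back to an auto-generated, uid-based name when no key is supplied, for example (a sketch; the numeric suffix depends on the uid counter):

import nni.retiarii.nn.pytorch as nn

lc = nn.LayerChoice([nn.Conv2d(16, 16, 3, padding=1), nn.Identity()])
print(lc.label)   # e.g. 'layerchoice_3' when no key is given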
...@@ -101,6 +97,7 @@ class Placeholder(nn.Module): ...@@ -101,6 +97,7 @@ class Placeholder(nn.Module):
class ChosenInputs(nn.Module): class ChosenInputs(nn.Module):
""" """
""" """
def __init__(self, chosen: List[int], reduction: str): def __init__(self, chosen: List[int], reduction: str):
super().__init__() super().__init__()
self.chosen = chosen self.chosen = chosen
...@@ -128,9 +125,7 @@ class ChosenInputs(nn.Module): ...@@ -128,9 +125,7 @@ class ChosenInputs(nn.Module):
# the following are pytorch modules # the following are pytorch modules
class Module(nn.Module): Module = nn.Module
def __init__(self):
super(Module, self).__init__()
class Sequential(nn.Sequential): class Sequential(nn.Sequential):
...@@ -145,143 +140,116 @@ class ModuleList(nn.ModuleList): ...@@ -145,143 +140,116 @@ class ModuleList(nn.ModuleList):
super(ModuleList, self).__init__(*args) super(ModuleList, self).__init__(*args)
def wrap_module(original_class): Identity = blackbox_module(nn.Identity)
orig_init = original_class.__init__ Linear = blackbox_module(nn.Linear)
argname_list = list(inspect.signature(original_class).parameters.keys()) Conv1d = blackbox_module(nn.Conv1d)
# Make copy of original __init__, so we can call it without recursion Conv2d = blackbox_module(nn.Conv2d)
Conv3d = blackbox_module(nn.Conv3d)
def __init__(self, *args, **kws): ConvTranspose1d = blackbox_module(nn.ConvTranspose1d)
full_args = {} ConvTranspose2d = blackbox_module(nn.ConvTranspose2d)
full_args.update(kws) ConvTranspose3d = blackbox_module(nn.ConvTranspose3d)
for i, arg in enumerate(args): Threshold = blackbox_module(nn.Threshold)
full_args[argname_list[i]] = arg ReLU = blackbox_module(nn.ReLU)
add_record(id(self), full_args) Hardtanh = blackbox_module(nn.Hardtanh)
ReLU6 = blackbox_module(nn.ReLU6)
orig_init(self, *args, **kws) # Call the original __init__ Sigmoid = blackbox_module(nn.Sigmoid)
Tanh = blackbox_module(nn.Tanh)
original_class.__init__ = __init__ # Set the class' __init__ to the new one Softmax = blackbox_module(nn.Softmax)
return original_class Softmax2d = blackbox_module(nn.Softmax2d)
LogSoftmax = blackbox_module(nn.LogSoftmax)
ELU = blackbox_module(nn.ELU)
Identity = wrap_module(nn.Identity) SELU = blackbox_module(nn.SELU)
Linear = wrap_module(nn.Linear) CELU = blackbox_module(nn.CELU)
Conv1d = wrap_module(nn.Conv1d) GLU = blackbox_module(nn.GLU)
Conv2d = wrap_module(nn.Conv2d) GELU = blackbox_module(nn.GELU)
Conv3d = wrap_module(nn.Conv3d) Hardshrink = blackbox_module(nn.Hardshrink)
ConvTranspose1d = wrap_module(nn.ConvTranspose1d) LeakyReLU = blackbox_module(nn.LeakyReLU)
ConvTranspose2d = wrap_module(nn.ConvTranspose2d) LogSigmoid = blackbox_module(nn.LogSigmoid)
ConvTranspose3d = wrap_module(nn.ConvTranspose3d) Softplus = blackbox_module(nn.Softplus)
Threshold = wrap_module(nn.Threshold) Softshrink = blackbox_module(nn.Softshrink)
ReLU = wrap_module(nn.ReLU) MultiheadAttention = blackbox_module(nn.MultiheadAttention)
Hardtanh = wrap_module(nn.Hardtanh) PReLU = blackbox_module(nn.PReLU)
ReLU6 = wrap_module(nn.ReLU6) Softsign = blackbox_module(nn.Softsign)
Sigmoid = wrap_module(nn.Sigmoid) Softmin = blackbox_module(nn.Softmin)
Tanh = wrap_module(nn.Tanh) Tanhshrink = blackbox_module(nn.Tanhshrink)
Softmax = wrap_module(nn.Softmax) RReLU = blackbox_module(nn.RReLU)
Softmax2d = wrap_module(nn.Softmax2d) AvgPool1d = blackbox_module(nn.AvgPool1d)
LogSoftmax = wrap_module(nn.LogSoftmax) AvgPool2d = blackbox_module(nn.AvgPool2d)
ELU = wrap_module(nn.ELU) AvgPool3d = blackbox_module(nn.AvgPool3d)
SELU = wrap_module(nn.SELU) MaxPool1d = blackbox_module(nn.MaxPool1d)
CELU = wrap_module(nn.CELU) MaxPool2d = blackbox_module(nn.MaxPool2d)
GLU = wrap_module(nn.GLU) MaxPool3d = blackbox_module(nn.MaxPool3d)
GELU = wrap_module(nn.GELU) MaxUnpool1d = blackbox_module(nn.MaxUnpool1d)
Hardshrink = wrap_module(nn.Hardshrink) MaxUnpool2d = blackbox_module(nn.MaxUnpool2d)
LeakyReLU = wrap_module(nn.LeakyReLU) MaxUnpool3d = blackbox_module(nn.MaxUnpool3d)
LogSigmoid = wrap_module(nn.LogSigmoid) FractionalMaxPool2d = blackbox_module(nn.FractionalMaxPool2d)
Softplus = wrap_module(nn.Softplus) FractionalMaxPool3d = blackbox_module(nn.FractionalMaxPool3d)
Softshrink = wrap_module(nn.Softshrink) LPPool1d = blackbox_module(nn.LPPool1d)
MultiheadAttention = wrap_module(nn.MultiheadAttention) LPPool2d = blackbox_module(nn.LPPool2d)
PReLU = wrap_module(nn.PReLU) LocalResponseNorm = blackbox_module(nn.LocalResponseNorm)
Softsign = wrap_module(nn.Softsign) BatchNorm1d = blackbox_module(nn.BatchNorm1d)
Softmin = wrap_module(nn.Softmin) BatchNorm2d = blackbox_module(nn.BatchNorm2d)
Tanhshrink = wrap_module(nn.Tanhshrink) BatchNorm3d = blackbox_module(nn.BatchNorm3d)
RReLU = wrap_module(nn.RReLU) InstanceNorm1d = blackbox_module(nn.InstanceNorm1d)
AvgPool1d = wrap_module(nn.AvgPool1d) InstanceNorm2d = blackbox_module(nn.InstanceNorm2d)
AvgPool2d = wrap_module(nn.AvgPool2d) InstanceNorm3d = blackbox_module(nn.InstanceNorm3d)
AvgPool3d = wrap_module(nn.AvgPool3d) LayerNorm = blackbox_module(nn.LayerNorm)
MaxPool1d = wrap_module(nn.MaxPool1d) GroupNorm = blackbox_module(nn.GroupNorm)
MaxPool2d = wrap_module(nn.MaxPool2d) SyncBatchNorm = blackbox_module(nn.SyncBatchNorm)
MaxPool3d = wrap_module(nn.MaxPool3d) Dropout = blackbox_module(nn.Dropout)
MaxUnpool1d = wrap_module(nn.MaxUnpool1d) Dropout2d = blackbox_module(nn.Dropout2d)
MaxUnpool2d = wrap_module(nn.MaxUnpool2d) Dropout3d = blackbox_module(nn.Dropout3d)
MaxUnpool3d = wrap_module(nn.MaxUnpool3d) AlphaDropout = blackbox_module(nn.AlphaDropout)
FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d) FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout)
FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d) ReflectionPad1d = blackbox_module(nn.ReflectionPad1d)
LPPool1d = wrap_module(nn.LPPool1d) ReflectionPad2d = blackbox_module(nn.ReflectionPad2d)
LPPool2d = wrap_module(nn.LPPool2d) ReplicationPad2d = blackbox_module(nn.ReplicationPad2d)
LocalResponseNorm = wrap_module(nn.LocalResponseNorm) ReplicationPad1d = blackbox_module(nn.ReplicationPad1d)
BatchNorm1d = wrap_module(nn.BatchNorm1d) ReplicationPad3d = blackbox_module(nn.ReplicationPad3d)
BatchNorm2d = wrap_module(nn.BatchNorm2d) CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d)
BatchNorm3d = wrap_module(nn.BatchNorm3d) Embedding = blackbox_module(nn.Embedding)
InstanceNorm1d = wrap_module(nn.InstanceNorm1d) EmbeddingBag = blackbox_module(nn.EmbeddingBag)
InstanceNorm2d = wrap_module(nn.InstanceNorm2d) RNNBase = blackbox_module(nn.RNNBase)
InstanceNorm3d = wrap_module(nn.InstanceNorm3d) RNN = blackbox_module(nn.RNN)
LayerNorm = wrap_module(nn.LayerNorm) LSTM = blackbox_module(nn.LSTM)
GroupNorm = wrap_module(nn.GroupNorm) GRU = blackbox_module(nn.GRU)
SyncBatchNorm = wrap_module(nn.SyncBatchNorm) RNNCellBase = blackbox_module(nn.RNNCellBase)
Dropout = wrap_module(nn.Dropout) RNNCell = blackbox_module(nn.RNNCell)
Dropout2d = wrap_module(nn.Dropout2d) LSTMCell = blackbox_module(nn.LSTMCell)
Dropout3d = wrap_module(nn.Dropout3d) GRUCell = blackbox_module(nn.GRUCell)
AlphaDropout = wrap_module(nn.AlphaDropout) PixelShuffle = blackbox_module(nn.PixelShuffle)
FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout) Upsample = blackbox_module(nn.Upsample)
ReflectionPad1d = wrap_module(nn.ReflectionPad1d) UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d)
ReflectionPad2d = wrap_module(nn.ReflectionPad2d) UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d)
ReplicationPad2d = wrap_module(nn.ReplicationPad2d) PairwiseDistance = blackbox_module(nn.PairwiseDistance)
ReplicationPad1d = wrap_module(nn.ReplicationPad1d) AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d)
ReplicationPad3d = wrap_module(nn.ReplicationPad3d) AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d)
CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d) AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d)
Embedding = wrap_module(nn.Embedding) AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d)
EmbeddingBag = wrap_module(nn.EmbeddingBag) AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d)
RNNBase = wrap_module(nn.RNNBase) AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d)
RNN = wrap_module(nn.RNN) TripletMarginLoss = blackbox_module(nn.TripletMarginLoss)
LSTM = wrap_module(nn.LSTM) ZeroPad2d = blackbox_module(nn.ZeroPad2d)
GRU = wrap_module(nn.GRU) ConstantPad1d = blackbox_module(nn.ConstantPad1d)
RNNCellBase = wrap_module(nn.RNNCellBase) ConstantPad2d = blackbox_module(nn.ConstantPad2d)
RNNCell = wrap_module(nn.RNNCell) ConstantPad3d = blackbox_module(nn.ConstantPad3d)
LSTMCell = wrap_module(nn.LSTMCell) Bilinear = blackbox_module(nn.Bilinear)
GRUCell = wrap_module(nn.GRUCell) CosineSimilarity = blackbox_module(nn.CosineSimilarity)
PixelShuffle = wrap_module(nn.PixelShuffle) Unfold = blackbox_module(nn.Unfold)
Upsample = wrap_module(nn.Upsample) Fold = blackbox_module(nn.Fold)
UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d) AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d) TransformerEncoder = blackbox_module(nn.TransformerEncoder)
PairwiseDistance = wrap_module(nn.PairwiseDistance) TransformerDecoder = blackbox_module(nn.TransformerDecoder)
AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d) TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d) TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d) Transformer = blackbox_module(nn.Transformer)
AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d) Flatten = blackbox_module(nn.Flatten)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d) Hardsigmoid = blackbox_module(nn.Hardsigmoid)
AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
ZeroPad2d = wrap_module(nn.ZeroPad2d)
ConstantPad1d = wrap_module(nn.ConstantPad1d)
ConstantPad2d = wrap_module(nn.ConstantPad2d)
ConstantPad3d = wrap_module(nn.ConstantPad3d)
Bilinear = wrap_module(nn.Bilinear)
CosineSimilarity = wrap_module(nn.CosineSimilarity)
Unfold = wrap_module(nn.Unfold)
Fold = wrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = wrap_module(nn.TransformerEncoder)
TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer)
Flatten = wrap_module(nn.Flatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'): if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = wrap_module(nn.Hardswish) Hardswish = blackbox_module(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'): if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = wrap_module(nn.SiLU) SiLU = blackbox_module(nn.SiLU)
Unflatten = wrap_module(nn.Unflatten) Unflatten = blackbox_module(nn.Unflatten)
TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss) TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss)
#LazyLinear = wrap_module(nn.LazyLinear)
#LazyConv1d = wrap_module(nn.LazyConv1d)
#LazyConv2d = wrap_module(nn.LazyConv2d)
#LazyConv3d = wrap_module(nn.LazyConv3d)
#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
#ChannelShuffle = wrap_module(nn.ChannelShuffle)
\ No newline at end of file
...@@ -43,7 +43,7 @@ def get_default_transform(dataset: str) -> Any: ...@@ -43,7 +43,7 @@ def get_default_transform(dataset: str) -> Any:
return None return None
@register_trainer() @register_trainer
class PyTorchImageClassificationTrainer(BaseTrainer): class PyTorchImageClassificationTrainer(BaseTrainer):
""" """
Image classification trainer for PyTorch. Image classification trainer for PyTorch.
...@@ -80,7 +80,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer): ...@@ -80,7 +80,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently, Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
only the key ``max_epochs`` is useful. only the key ``max_epochs`` is useful.
""" """
super(PyTorchImageClassificationTrainer, self).__init__() super().__init__()
self._use_cuda = torch.cuda.is_available() self._use_cuda = torch.cuda.is_available()
self.model = model self.model = model
if self._use_cuda: if self._use_cuda:
......
import inspect import inspect
import warnings
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any
...@@ -10,12 +11,14 @@ def import_(target: str, allow_none: bool = False) -> Any: ...@@ -10,12 +11,14 @@ def import_(target: str, allow_none: bool = False) -> Any:
module = __import__(path, globals(), locals(), [identifier]) module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier) return getattr(module, identifier)
def version_larger_equal(a: str, b: str) -> bool: def version_larger_equal(a: str, b: str) -> bool:
# TODO: refactor later # TODO: refactor later
a = a.split('+')[0] a = a.split('+')[0]
b = b.split('+')[0] b = b.split('+')[0]
return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.'))) return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))
_records = {} _records = {}
...@@ -29,73 +32,87 @@ def add_record(key, value): ...@@ -29,73 +32,87 @@ def add_record(key, value):
""" """
global _records global _records
if _records is not None: if _records is not None:
#assert key not in _records, '{} already in _records'.format(key) assert key not in _records, '{} already in _records'.format(key)
_records[key] = value _records[key] = value
def _register_module(original_class): def del_record(key):
orig_init = original_class.__init__ global _records
argname_list = list(inspect.signature(original_class).parameters.keys()) if _records is not None:
# Make copy of original __init__, so we can call it without recursion _records.pop(key, None)
def __init__(self, *args, **kws): def _blackbox_cls(cls, module_name, register_format=None):
class wrapper(cls):
def __init__(self, *args, **kwargs):
argname_list = list(inspect.signature(cls).parameters.keys())
full_args = {} full_args = {}
full_args.update(kws) full_args.update(kwargs)
for i, arg in enumerate(args):
full_args[argname_list[i]] = arg assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
for argname, value in zip(argname_list, args):
full_args[argname] = value
# eject un-serializable arguments
for k in list(full_args.keys()):
# The list is not complete and does not support nested cases.
if not isinstance(full_args[k], (int, float, str, dict, list)):
if not (register_format == 'full' and k == 'model'):
# no warning if it is base model in trainer
warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
This is not supported. You can ignore this warning if you are passing the model to trainer.')
full_args.pop(k)
if register_format == 'args':
add_record(id(self), full_args) add_record(id(self), full_args)
elif register_format == 'full':
full_class_name = cls.__module__ + '.' + cls.__name__
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
orig_init(self, *args, **kws) # Call the original __init__ super().__init__(*args, **kwargs)
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
def register_module(): def __del__(self):
""" del_record(id(self))
Register a module.
"""
# use it as a decorator: @register_module()
def _register(cls):
m = _register_module(
original_class=cls)
return m
return _register # using module_name instead of cls.__module__ because it's more natural to see where the module gets wrapped
# instead of simply putting torch.nn or etc.
wrapper.__module__ = module_name
wrapper.__name__ = cls.__name__
wrapper.__qualname__ = cls.__qualname__
wrapper.__init__.__doc__ = cls.__init__.__doc__
return wrapper
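To see what the wrapper records, one can instantiate a wrapped module and look up its entry (a sketch; get_records is assumed to live in this same utils module, and the shown dict assumes the two positional arguments bind to in_features and out_features):

import nni.retiarii.nn.pytorch as nn
from nni.retiarii.utils import get_records

fc = nn.Linear(16, 8)          # nn.Linear here is torch.nn.Linear wrapped by blackbox_module
print(get_records()[id(fc)])   # -> {'in_features': 16, 'out_features': 8}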
def _register_trainer(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
full_class_name = original_class.__module__ + '.' + original_class.__name__ def blackbox(cls, *args, **kwargs):
"""
def __init__(self, *args, **kws): To create an blackbox instance inline without decorator. For example,
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
# TODO: support both pytorch and tensorflow
from .nn.pytorch import Module
if isinstance(args[i], Module):
# ignore the base model object
continue
full_args[argname_list[i]] = arg
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
orig_init(self, *args, **kws) # Call the original __init__ .. code-block:: python
self.op = blackbox(MyCustomOp, hidden_units=128)
"""
# get caller module name
frm = inspect.stack()[1]
module_name = inspect.getmodule(frm[0]).__name__
return _blackbox_cls(cls, module_name, 'args')(*args, **kwargs)
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
def blackbox_module(cls):
"""
Register a module. Use it as a decorator.
"""
frm = inspect.stack()[1]
module_name = inspect.getmodule(frm[0]).__name__
return _blackbox_cls(cls, module_name, 'args')
def register_trainer():
def _register(cls):
m = _register_trainer(
original_class=cls)
return m
return _register def register_trainer(cls):
"""
Register a trainer. Use it as a decorator.
"""
frm = inspect.stack()[1]
module_name = inspect.getmodule(frm[0]).__name__
return _blackbox_cls(cls, module_name, 'full')
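A sketch of registering a customized trainer with the new decorator (note it is now used without parentheses; the class body is an illustrative assumption, and a real trainer would subclass the framework's trainer base class):

from nni.retiarii import register_trainer

@register_trainer                 # records the full class name plus serializable __init__ args
class MyTrainer:
    def __init__(self, model, max_epochs=10, learning_rate=1e-3):
        self.model = model        # the base model itself is skipped when recording args
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate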
_last_uid = defaultdict(int) _last_uid = defaultdict(int)
......
...@@ -5,6 +5,7 @@ tuner_result.txt ...@@ -5,6 +5,7 @@ tuner_result.txt
assessor_result.txt assessor_result.txt
_generated_model.py _generated_model.py
_generated_model_*.py
data data
generated generated
...@@ -7,9 +7,9 @@ import torch.nn as torch_nn ...@@ -7,9 +7,9 @@ import torch.nn as torch_nn
import ops import ops
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module from nni.retiarii import blackbox_module
@blackbox_module
class AuxiliaryHead(nn.Module): class AuxiliaryHead(nn.Module):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """ """ Auxiliary head in 2/3 place of network to let the gradient flow well """
...@@ -35,7 +35,6 @@ class AuxiliaryHead(nn.Module): ...@@ -35,7 +35,6 @@ class AuxiliaryHead(nn.Module):
logits = self.linear(out) logits = self.linear(out)
return logits return logits
@register_module()
class Node(nn.Module): class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect): def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__() super().__init__()
...@@ -66,7 +65,6 @@ class Node(nn.Module): ...@@ -66,7 +65,6 @@ class Node(nn.Module):
#out = [self.drop_path(o) if o is not None else None for o in out] #out = [self.drop_path(o) if o is not None else None for o in out]
return self.input_switch(out) return self.input_switch(out)
@register_module()
class Cell(nn.Module): class Cell(nn.Module):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
...@@ -100,7 +98,6 @@ class Cell(nn.Module): ...@@ -100,7 +98,6 @@ class Cell(nn.Module):
output = torch.cat(new_tensors, dim=1) output = torch.cat(new_tensors, dim=1)
return output return output
@register_module()
class CNN(nn.Module): class CNN(nn.Module):
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4, def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
......
import torch import torch
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module from nni.retiarii import blackbox_module
@register_module() @blackbox_module
class DropPath(nn.Module): class DropPath(nn.Module):
def __init__(self, p=0.): def __init__(self, p=0.):
""" """
...@@ -12,7 +12,7 @@ class DropPath(nn.Module): ...@@ -12,7 +12,7 @@ class DropPath(nn.Module):
p : float p : float
Probability of an path to be zeroed. Probability of an path to be zeroed.
""" """
super(DropPath, self).__init__() super().__init__()
self.p = p self.p = p
def forward(self, x): def forward(self, x):
...@@ -24,13 +24,13 @@ class DropPath(nn.Module): ...@@ -24,13 +24,13 @@ class DropPath(nn.Module):
return x return x
@register_module() @blackbox_module
class PoolBN(nn.Module): class PoolBN(nn.Module):
""" """
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`. AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
""" """
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True): def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
super(PoolBN, self).__init__() super().__init__()
if pool_type.lower() == 'max': if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding) self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg': elif pool_type.lower() == 'avg':
...@@ -45,13 +45,13 @@ class PoolBN(nn.Module): ...@@ -45,13 +45,13 @@ class PoolBN(nn.Module):
out = self.bn(out) out = self.bn(out)
return out return out
@register_module() @blackbox_module
class StdConv(nn.Module): class StdConv(nn.Module):
""" """
Standard conv: ReLU - Conv - BN Standard conv: ReLU - Conv - BN
""" """
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(StdConv, self).__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.ReLU(), nn.ReLU(),
nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False), nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
...@@ -61,13 +61,13 @@ class StdConv(nn.Module): ...@@ -61,13 +61,13 @@ class StdConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@register_module() @blackbox_module
class FacConv(nn.Module): class FacConv(nn.Module):
""" """
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
""" """
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True): def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super(FacConv, self).__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.ReLU(), nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False), nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
...@@ -78,7 +78,7 @@ class FacConv(nn.Module): ...@@ -78,7 +78,7 @@ class FacConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@register_module() @blackbox_module
class DilConv(nn.Module): class DilConv(nn.Module):
""" """
(Dilated) depthwise separable conv. (Dilated) depthwise separable conv.
...@@ -86,7 +86,7 @@ class DilConv(nn.Module): ...@@ -86,7 +86,7 @@ class DilConv(nn.Module):
If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field. If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
""" """
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super(DilConv, self).__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.ReLU(), nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
...@@ -98,14 +98,14 @@ class DilConv(nn.Module): ...@@ -98,14 +98,14 @@ class DilConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@register_module() @blackbox_module
class SepConv(nn.Module): class SepConv(nn.Module):
""" """
Depthwise separable conv. Depthwise separable conv.
DilConv(dilation=1) * 2. DilConv(dilation=1) * 2.
""" """
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine), DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine) DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
...@@ -114,13 +114,13 @@ class SepConv(nn.Module): ...@@ -114,13 +114,13 @@ class SepConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@register_module() @blackbox_module
class FactorizedReduce(nn.Module): class FactorizedReduce(nn.Module):
""" """
Reduce feature map size by factorized pointwise (stride=2). Reduce feature map size by factorized pointwise (stride=2).
""" """
def __init__(self, C_in, C_out, affine=True): def __init__(self, C_in, C_out, affine=True):
super(FactorizedReduce, self).__init__() super().__init__()
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
......
...@@ -31,4 +31,4 @@ if __name__ == '__main__': ...@@ -31,4 +31,4 @@ if __name__ == '__main__':
exp_config.training_service.use_active_gpu = True exp_config.training_service.use_active_gpu = True
exp_config.training_service.gpu_indices = [1, 2] exp_config.training_service.gpu_indices = [1, 2]
exp.run(exp_config, 8081, debug=True) exp.run(exp_config, 8081)
...@@ -56,8 +56,8 @@ def get_dataset(cls, cutout_length=0): ...@@ -56,8 +56,8 @@ def get_dataset(cls, cutout_length=0):
valid_transform = transforms.Compose(normalize) valid_transform = transforms.Compose(normalize)
if cls == "cifar10": if cls == "cifar10":
dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform) dataset_train = CIFAR10(root="./data/cifar10", train=True, download=True, transform=train_transform)
dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform) dataset_valid = CIFAR10(root="./data/cifar10", train=False, download=True, transform=valid_transform)
else: else:
raise NotImplementedError raise NotImplementedError
return dataset_train, dataset_valid return dataset_train, dataset_valid
......
from nni.retiarii import blackbox_module
import nni.retiarii.nn.pytorch as nn
import warnings import warnings
import torch import torch
...@@ -8,8 +10,6 @@ import torch.nn.functional as F ...@@ -8,8 +10,6 @@ import torch.nn.functional as F
import sys import sys
from pathlib import Path from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[2])) sys.path.append(str(Path(__file__).resolve().parents[2]))
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module
# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow. # 1.0 - tensorflow.
...@@ -27,6 +27,7 @@ class _ResidualBlock(nn.Module): ...@@ -27,6 +27,7 @@ class _ResidualBlock(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) + x return self.net(x) + x
class _InvertedResidual(nn.Module): class _InvertedResidual(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, skip, bn_momentum=0.1): def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, skip, bn_momentum=0.1):
...@@ -110,7 +111,7 @@ def _get_depths(depths, alpha): ...@@ -110,7 +111,7 @@ def _get_depths(depths, alpha):
rather than down. """ rather than down. """
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
@register_module()
class MNASNet(nn.Module): class MNASNet(nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model. implements the B1 variant of the model.
...@@ -127,7 +128,7 @@ class MNASNet(nn.Module): ...@@ -127,7 +128,7 @@ class MNASNet(nn.Module):
def __init__(self, alpha, depths, convops, kernel_sizes, num_layers, def __init__(self, alpha, depths, convops, kernel_sizes, num_layers,
skips, num_classes=1000, dropout=0.2): skips, num_classes=1000, dropout=0.2):
super(MNASNet, self).__init__() super().__init__()
assert alpha > 0.0 assert alpha > 0.0
assert len(depths) == len(convops) == len(kernel_sizes) == len(num_layers) == len(skips) == 7 assert len(depths) == len(convops) == len(kernel_sizes) == len(num_layers) == len(skips) == 7
self.alpha = alpha self.alpha = alpha
...@@ -143,7 +144,7 @@ class MNASNet(nn.Module): ...@@ -143,7 +144,7 @@ class MNASNet(nn.Module):
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
] ]
count = 0 count = 0
#for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \ # for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \
# zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios): # zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios):
for filter_size, exp_ratio, stride in zip(base_filter_sizes, exp_ratios, strides): for filter_size, exp_ratio, stride in zip(base_filter_sizes, exp_ratios, strides):
# TODO: restrict that "choose" can only be used within mutator # TODO: restrict that "choose" can only be used within mutator
...@@ -153,7 +154,7 @@ class MNASNet(nn.Module): ...@@ -153,7 +154,7 @@ class MNASNet(nn.Module):
'op_type_options': ['__mutated__.base_mnasnet.RegularConv', 'op_type_options': ['__mutated__.base_mnasnet.RegularConv',
'__mutated__.base_mnasnet.DepthwiseConv', '__mutated__.base_mnasnet.DepthwiseConv',
'__mutated__.base_mnasnet.MobileConv'], '__mutated__.base_mnasnet.MobileConv'],
#'se_ratio_options': [0, 0.25], # 'se_ratio_options': [0, 0.25],
'skip_options': ['identity', 'no'], 'skip_options': ['identity', 'no'],
'n_filter_options': [int(filter_size*x) for x in [0.75, 1.0, 1.25]], 'n_filter_options': [int(filter_size*x) for x in [0.75, 1.0, 1.25]],
'exp_ratio': exp_ratio, 'exp_ratio': exp_ratio,
...@@ -185,7 +186,7 @@ class MNASNet(nn.Module): ...@@ -185,7 +186,7 @@ class MNASNet(nn.Module):
#self.for_test = 10 #self.for_test = 10
def forward(self, x): def forward(self, x):
#if self.for_test == 10: # if self.for_test == 10:
x = self.layers(x) x = self.layers(x)
# Equivalent to global avgpool and removing H and W dimensions. # Equivalent to global avgpool and removing H and W dimensions.
x = x.mean([2, 3]) x = x.mean([2, 3])
...@@ -211,9 +212,11 @@ class MNASNet(nn.Module): ...@@ -211,9 +212,11 @@ class MNASNet(nn.Module):
def test_model(model): def test_model(model):
model(torch.randn(2, 3, 224, 224)) model(torch.randn(2, 3, 224, 224))
#====================definition of candidate op classes
# ====================definition of candidate op classes
BN_MOMENTUM = 1 - 0.9997 BN_MOMENTUM = 1 - 0.9997
class RegularConv(nn.Module): class RegularConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride): def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__() super().__init__()
...@@ -234,6 +237,7 @@ class RegularConv(nn.Module): ...@@ -234,6 +237,7 @@ class RegularConv(nn.Module):
out = out + x out = out + x
return out return out
class DepthwiseConv(nn.Module): class DepthwiseConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride): def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__() super().__init__()
...@@ -257,6 +261,7 @@ class DepthwiseConv(nn.Module): ...@@ -257,6 +261,7 @@ class DepthwiseConv(nn.Module):
out = out + x out = out + x
return out return out
class MobileConv(nn.Module): class MobileConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride): def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__() super().__init__()
...@@ -274,7 +279,7 @@ class MobileConv(nn.Module): ...@@ -274,7 +279,7 @@ class MobileConv(nn.Module):
nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM), nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
# Depthwise # Depthwise
nn.Conv2d(mid_ch, mid_ch, kernel_size, padding= (kernel_size - 1) // 2, nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=(kernel_size - 1) // 2,
stride=stride, groups=mid_ch, bias=False), stride=stride, groups=mid_ch, bias=False),
nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM), nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
...@@ -288,5 +293,6 @@ class MobileConv(nn.Module): ...@@ -288,5 +293,6 @@ class MobileConv(nn.Module):
out = out + x out = out + x
return out return out
# mnasnet0_5 # mnasnet0_5
ir_module = _InvertedResidual(16, 16, 3, 1, 1, True) ir_module = _InvertedResidual(16, 16, 3, 1, 1, True)
...@@ -41,4 +41,4 @@ if __name__ == '__main__': ...@@ -41,4 +41,4 @@ if __name__ == '__main__':
exp_config.max_trial_number = 10 exp_config.max_trial_number = 10
exp_config.training_service.use_active_gpu = False exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081, debug=True) exp.run(exp_config, 8081)
import random
import nni.retiarii.nn.pytorch as nn
import torch.nn.functional as F
from nni.retiarii.experiment import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.strategies import RandomStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer
class Net(nn.Module):
def __init__(self, hidden_size):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.LayerChoice([
nn.Linear(4*4*50, hidden_size),
nn.Linear(4*4*50, hidden_size, bias=False)
])
self.fc2 = nn.Linear(hidden_size, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
if __name__ == '__main__':
base_model = Net(128)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="MNIST",
dataset_kwargs={"root": "data/mnist", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
simple_strategy = RandomStrategy()
exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_search'
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 10
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081 + random.randint(0, 100))