Unverified commit ae50ed14 authored by Yuge Zhang, committed by GitHub

Refactor wrap module as "blackbox_module" (#3238)

parent 15da19d3
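In short, `register_module` is replaced by `blackbox_module` (decorator form) and `blackbox` (inline form), and trainers are registered with `register_trainer`. A rough usage sketch; `DecoratedOp`, `PlainOp` and the `hidden_units` argument are illustrative names, not code from this commit:

    import nni.retiarii.nn.pytorch as nn
    from nni.retiarii import blackbox, blackbox_module

    @blackbox_module                      # decorator form (was @register_module() before this commit)
    class DecoratedOp(nn.Module):
        def __init__(self, hidden_units=128):
            super().__init__()
            self.fc = nn.Linear(hidden_units, hidden_units)

        def forward(self, x):
            return self.fc(x)

    class PlainOp(nn.Module):             # not decorated
        def __init__(self, hidden_units=128):
            super().__init__()
            self.fc = nn.Linear(hidden_units, hidden_units)

        def forward(self, x):
            return self.fc(x)

    op1 = DecoratedOp(hidden_units=64)            # constructor arguments are recorded automatically
    op2 = blackbox(PlainOp, hidden_units=64)      # inline form for a class without the decorator

The wrappers record the constructor arguments (keyed by instance id) so that the graph converter can rebuild such a module without tracing into its body.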
@@ -2,4 +2,4 @@ from .operation import Operation
 from .graph import *
 from .execution import *
 from .mutator import *
-from .utils import register_module
+from .utils import blackbox, blackbox_module, register_trainer
\ No newline at end of file
@@ -19,10 +19,10 @@ def model_to_pytorch_script(model: Model, placement=None) -> str:
 def _sorted_incoming_edges(node: Node) -> List[Edge]:
     edges = [edge for edge in node.graph.edges if edge.tail is node]
-    _logger.info('sorted_incoming_edges: %s', str(edges))
+    _logger.debug('sorted_incoming_edges: %s', str(edges))
     if not edges:
         return []
-    _logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
+    _logger.debug('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
     if all(edge.tail_slot is None for edge in edges):
         return edges
     if all(isinstance(edge.tail_slot, int) for edge in edges):
...
@@ -6,518 +6,501 @@ import torch
 from ..graph import Graph, Model, Node
 from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
 from ..operation import Cell
+from ..utils import get_records
 from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
 from .utils import _convert_name, build_full_name

 _logger = logging.getLogger(__name__)

-global_seq = 0
-global_graph_id = 0
-modules_arg = None

+class GraphConverter:
+    def __init__(self):
+        self.global_seq = 0
+        self.global_graph_id = 0
+        self.modules_arg = get_records()
-def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
+    def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
         """
         Parameters
         ----------
         ir_graph : Graph
         node : torch._C.Node
         graph_inputs : List[torch._C.Value]
             a list of a script graph's inputs
         node_index : Dict
         new_node : Node
             newly created ir node corresponding to `node`
         output_remap : Dict
         ignore_first : bool
             if it is true, skip the first input
         """
         is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
         new_node_input_idx = 0
         for _input in node.inputs():
             if ignore_first:
                 ignore_first = False
                 continue
             # handle source node
             if _input in graph_inputs:
                 idx = graph_inputs.index(_input)
                 src_node = ir_graph.input_node
                 src_node_idx = idx
             elif _input in output_remap:
                 assert output_remap[_input].kind() == 'aten::append'
                 predecessor_node = output_remap[_input]
                 assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
                 src_node_idx = None
                 src_node = node_index[predecessor_node]
                 assert isinstance(src_node, Node)
             else:
                 predecessor_node = _input.node()
                 assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
                 # find out the index of _input in the outputs of predecessor_node
                 predecessor_outputs = [_output for _output in predecessor_node.outputs()]
                 if len(predecessor_outputs) == 1:
                     idx = None
                 else:
                     idx = predecessor_outputs.index(_input)
                 ir_predecessor_node = node_index[predecessor_node]
                 src_node_idx = idx
                 assert isinstance(ir_predecessor_node, Node)
                 src_node = ir_predecessor_node

             # handle destination node
             dst_node = new_node
             if is_single_input:
                 dst_node_idx = None
             else:
                 dst_node_idx = new_node_input_idx

             # create edge
             ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))

             new_node_input_idx += 1
def create_prim_constant_node(self, ir_graph, node, module_name):
attrs = {}
if node.outputsAt(0).toIValue() is not None:
attrs = {'value': node.outputsAt(0).toIValue()}
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
node.kind(), attrs)
return new_node
def create_prim_constant_node(ir_graph, node, module_name): def handle_prim_attr_node(self, node):
global global_seq assert node.hasAttribute('name')
attrs = {} attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
if node.outputsAt(0).toIValue() is not None: return node.kind(), attrs
attrs = {'value': node.outputsAt(0).toIValue()}
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, global_seq),
node.kind(), attrs)
return new_node
def _remove_mangle(self, module_type_str):
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
def handle_prim_attr_node(node): def remove_unconnected_nodes(self, ir_graph, targeted_type=None):
assert node.hasAttribute('name') """
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()} Parameters
return node.kind(), attrs ----------
ir_graph : Graph
our ir graph representation
targeted_type : str
nodes with ```targeted_type``` will be removed from graph if their fanout is 0.
```None``` means removing all the nodes whose fanout is 0.
"""
# build index of outputs of Node(s)
node_fanout = set()
for edge in ir_graph.edges:
if edge.head.id not in node_fanout:
node_fanout.add(edge.head.id)
to_removes = []
for hidden_node in ir_graph.hidden_nodes:
if hidden_node.id not in node_fanout:
assert isinstance(hidden_node, Node)
if targeted_type is None:
to_removes.append(hidden_node)
elif hidden_node.operation.type == targeted_type:
to_removes.append(hidden_node)
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(self, script_module, sm_graph, module, module_name, ir_model, ir_graph):
"""
Convert torch script node to our node ir, and build our graph ir
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the torch script of ```module```
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
```module```'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ```module```
def _remove_mangle(module_type_str): Returns
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str) -------
dict
the mapping from graph node to our graph ir node
"""
# handle inputs
graph_inputs = []
for _input in sm_graph.inputs():
if _input.debugName() == 'self':
assert _input.unique() == 0
continue
graph_inputs.append(_input)
# TODO: add scope name
ir_graph._add_input(_convert_name(_input.debugName()))
node_index = {} # graph node to graph ir node
# some node does not have output but it modifies a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
# %out.1 is updated, and %17 is None
# we add output to this type of node and connect it to the following node which uses %out.1
# key: tensor (%out.1), value: node (this node)
output_remap = {}
def handle_if_condition(cond_tensor):
"""
to calculate the condition, we only deal with the following op types by tracing back
`prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`
generate the expression using recursive calls
NOTE: do not support dynamic graph
"""
def _generate_expr(tensor):
if tensor.node().kind() == 'prim::GetAttr':
return f'({getattr(module, tensor.node().s("name"))})'
elif tensor.node().kind() == 'aten::__getitem__':
t = _generate_expr(tensor.node().inputsAt(0))
idx = _generate_expr(tensor.node().inputsAt(1))
return f'({t}[{idx}])'
elif tensor.node().kind() == 'prim::Constant':
return f'{tensor.toIValue()}'
elif tensor.node().kind() == 'aten::eq':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})'
else:
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition')
expr = _generate_expr(cond_tensor)
return eval(expr)
def handle_if_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
# only deal with input of prim::If is constant or attribute for now
# will support constant expression in future
inputs = [i for i in node.inputs()]
assert len(inputs) == 1
cond = handle_if_condition(inputs[0])
chosen_block = 0 if cond else 1
blocks = [block for block in node.blocks()]
assert len(blocks) == 2
last_block_node = None
for node in blocks[chosen_block].nodes():
last_block_node = handle_single_node(node)
return last_block_node
def handle_single_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
if node.kind() == 'prim::CallMethod':
# get and handle the first input, which should be an nn.Module
assert node.hasAttribute('name')
if node.s('name') == 'forward':
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
submodule = node.inputsAt(0).node()
assert submodule.kind() == 'prim::GetAttr'
assert submodule.hasAttribute('name')
submodule_name = submodule.s('name')
if submodule.inputsAt(0).debugName() == 'self':
# module is usually instantiated in __init__.
# when calling a module in forward,
# prim::GetAttr is used to obtain the module in torch script.
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
submodule_name, script_module._modules.keys())
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = self.convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, ir_model)
else:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
if submodule.inputsAt(0).type().name() == 'ModuleList':
# handle ModuleList
predecessor = submodule.inputsAt(0).node()
assert predecessor.kind() == 'prim::GetAttr'
assert predecessor.hasAttribute('name')
assert predecessor.inputsAt(0).debugName() == 'self'
predecessor_name = predecessor.s('name')
# FIXME: exchange
submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
predecessor_obj = getattr(module, predecessor_name)
submodule_obj = getattr(predecessor_obj, submodule_name)
subgraph, sub_m_attrs = self.convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
submodule_obj, submodule_full_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
# TODO: match subgraph with maintained graphs
# build cell
if subgraph is None:
# if we do not parse this module's graph, we create Node for this module
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
elif isinstance(submodule_obj, (LayerChoice, InputChoice)):
subcell.update_label(sub_m_attrs['label'])
else:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
node_index[node] = subcell
# connect the cell into graph
self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
else:
raise RuntimeError('unsupported CallMethod {}'.format(node.s('name')))
elif node.kind() == 'prim::CallFunction':
func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
func = node.inputsAt(0).node()
assert func.kind() == 'prim::Constant'
assert func.hasAttribute('name')
func_name = func.s('name')
# create node for func
self.global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
'{}.{}'.format(func_type_str, func_name))
node_index[node] = func_node
self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant':
new_node = self.create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct':
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, self.global_seq), node.kind())
node_index[node] = new_node
self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append':
self.global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
self.global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = self.handle_prim_attr_node(node)
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# last_block_node is None means no node in the branch block
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise RuntimeError('Loop has not been supported yet!')
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
return node_index[node]
def remove_unconnected_nodes(ir_graph, targeted_type=None): for node in sm_graph.nodes():
""" handle_single_node(node)
Parameters
----------
ir_graph : Graph
our ir graph representation
targeted_type : str
nodes with ```targeted_type``` will be removed from graph if their fanout is 0.
```None``` means removing all the nodes whose fanout is 0.
"""
# build index of outputs of Node(s)
node_fanout = set()
for edge in ir_graph.edges:
if edge.head.id not in node_fanout:
node_fanout.add(edge.head.id)
to_removes = []
for hidden_node in ir_graph.hidden_nodes:
if hidden_node.id not in node_fanout:
assert isinstance(hidden_node, Node)
if targeted_type is None:
to_removes.append(hidden_node)
elif hidden_node.operation.type == targeted_type:
to_removes.append(hidden_node)
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
"""
Convert torch script node to our node ir, and build our graph ir
Parameters return node_index
----------
script_module : torch.jit.RecursiveScriptModule
the torch script of ```module```
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
```module```'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ```module```
Returns def merge_aten_slices(self, ir_graph):
-------
dict
the mapping from graph node to our graph ir node
"""
# handle inputs
graph_inputs = []
for _input in sm_graph.inputs():
if _input.debugName() == 'self':
assert _input.unique() == 0
continue
graph_inputs.append(_input)
# TODO: add scope name
ir_graph._add_input(_convert_name(_input.debugName()))
node_index = {} # graph node to graph ir node
# some node does not have output but it modifies a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
# %out.1 is updated, and %17 is None
# we add output to this type of node and connect it to the following node which uses %out.1
# key: tensor (%out.1), value: node (this node)
output_remap = {}
def handle_if_condition(cond_tensor):
""" """
to calculate the condition, we only deal with the following op types by tracing back if there is aten::slice node, merge the consecutive ones together.
`prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq` ```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
generate the expression using recursive calls
NOTE: do not support dynamic graph
""" """
def _generate_expr(tensor): head_slice_nodes = []
if tensor.node().kind() == 'prim::GetAttr': has_slice_node = False
return f'({getattr(module, tensor.node().s("name"))})' for node in ir_graph.hidden_nodes:
elif tensor.node().kind() == 'aten::__getitem__': if node.operation.type == 'aten::slice':
t = _generate_expr(tensor.node().inputsAt(0)) has_slice_node = True
idx = _generate_expr(tensor.node().inputsAt(1)) for pred in node.predecessors:
return f'({t}[{idx}])' if pred.operation.type not in ['aten::slice', 'prim::Constant']:
elif tensor.node().kind() == 'prim::Constant': head_slice_nodes.append(node)
return f'{tensor.toIValue()}' break
elif tensor.node().kind() == 'aten::eq': if has_slice_node:
left = _generate_expr(tensor.node().inputsAt(0)) assert head_slice_nodes
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})' for head_node in head_slice_nodes:
else: slot = 0
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition') new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
expr = _generate_expr(cond_tensor) if len(head_node.incoming_edges) == 4:
return eval(expr) # when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
slot += 5
node = head_node
while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
suc_node = node.successors[0]
assert len(suc_node.incoming_edges) == 5
for edge in suc_node.incoming_edges:
if edge.tail_slot == 0:
edge.remove()
else:
edge.tail = new_slice_node
edge.tail_slot = slot + edge.tail_slot - 1
slot += 4
ir_graph.hidden_nodes.remove(node)
node = suc_node
for edge in node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node)
def handle_if_node(node): def refine_graph(self, ir_graph):
""" """
Parameters Do the following process to simplify graph:
---------- 1. remove unconnected constant node
node : torch._C.Node 2. remove unconnected getattr node
the node from TorchScript graph
Returns
-------
Node
the created node ir
""" """
# only deal with input of prim::If is constant or attribute for now # some constant is not used, for example, function name as prim::Constant
# will support constant expression in future self.remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
inputs = [i for i in node.inputs()] self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
assert len(inputs) == 1 self.merge_aten_slices(ir_graph)
cond = handle_if_condition(inputs[0])
chosen_block = 0 if cond else 1 def _handle_layerchoice(self, module):
blocks = [block for block in node.blocks()] m_attrs = {}
assert len(blocks) == 2 candidates = module.op_candidates
last_block_node = None choices = []
for node in blocks[chosen_block].nodes(): for cand in candidates:
last_block_node = handle_single_node(node) assert id(cand) in self.modules_arg, 'id not exist: {}'.format(id(cand))
return last_block_node assert isinstance(self.modules_arg[id(cand)], dict)
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
def handle_single_node(node): choices.append({'type': cand_type, 'parameters': self.modules_arg[id(cand)]})
m_attrs[f'choices'] = choices
m_attrs['label'] = module.label
return m_attrs
def _handle_inputchoice(self, module):
m_attrs = {}
m_attrs['n_candidates'] = module.n_candidates
m_attrs['n_chosen'] = module.n_chosen
m_attrs['reduction'] = module.reduction
m_attrs['label'] = module.label
return m_attrs
def convert_module(self, script_module, module, module_name, ir_model):
""" """
Convert a module to its graph ir (i.e., Graph) along with its input arguments
Parameters Parameters
---------- ----------
node : torch._C.Node script_module : torch.jit.RecursiveScriptModule
the node from TorchScript graph the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
Returns Returns
------- -------
Node Graph
the created node ir the built graph ir from module, ```None``` means do not further parse the module
dict
the input arguments of this module
""" """
global global_seq
if node.kind() == 'prim::CallMethod': # NOTE: have not supported nested LayerChoice, i.e., a candidate module
# get and handle the first input, which should be an nn.Module # also has LayerChoice or InputChoice or ValueChoice
assert node.hasAttribute('name') original_type_name = script_module.original_name
if node.s('name') == 'forward': m_attrs = None
# node.inputsAt(0).type() is <class 'torch._C.ClassType'> if original_type_name in MODULE_EXCEPT_LIST:
submodule_type_str = _remove_mangle(node.inputsAt(0).type().str()) pass # do nothing
submodule = node.inputsAt(0).node() elif original_type_name == OpTypeName.LayerChoice:
assert submodule.kind() == 'prim::GetAttr' m_attrs = self._handle_layerchoice(module)
assert submodule.hasAttribute('name') elif original_type_name == OpTypeName.InputChoice:
submodule_name = submodule.s('name') m_attrs = self._handle_inputchoice(module)
elif original_type_name == OpTypeName.Placeholder:
if submodule.inputsAt(0).debugName() == 'self': m_attrs = self.modules_arg[id(module)]
# module is usually instantiated in __init__. elif original_type_name in torch.nn.__dict__:
# when calling a module in forward, # this is a basic module from pytorch, no need to parse its graph
# prim::GetAttr is used to obtain the module in torch script. assert id(module) in self.modules_arg, f'{original_type_name} arguments are not recorded'
# therefore, we do this check for a module. example below: m_attrs = self.modules_arg[id(module)]
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self) elif id(module) in self.modules_arg:
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1) # this module is marked as blackbox, won't continue to parse
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format( m_attrs = self.modules_arg[id(module)]
submodule_name, script_module._modules.keys()) if m_attrs is not None:
return None, m_attrs
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_obj = getattr(module, submodule_name) # handle TorchScript graph
subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name], sm_graph = script_module.graph
submodule_obj, self.global_graph_id += 1
submodule_full_name, ir_model) ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)
else:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self) # handle graph nodes
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8) node_index = self.handle_graph_nodes(script_module, sm_graph, module,
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4) module_name, ir_model, ir_graph)
if submodule.inputsAt(0).type().name() == 'ModuleList':
# handle ModuleList # handle graph outputs
predecessor = submodule.inputsAt(0).node() for _output in sm_graph.outputs():
assert predecessor.kind() == 'prim::GetAttr' ir_graph._add_output(_convert_name(_output.debugName()))
assert predecessor.hasAttribute('name') predecessor_node_outputs = [o for o in _output.node().outputs()]
assert predecessor.inputsAt(0).debugName() == 'self' if len(predecessor_node_outputs) == 1:
predecessor_name = predecessor.s('name') src_node_idx = None
# FIXME: exchange
submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
predecessor_obj = getattr(module, predecessor_name)
submodule_obj = getattr(predecessor_obj, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
submodule_obj, submodule_full_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
# TODO: match subgraph with maintained graphs
# build cell
if subgraph is None:
# if we do not parse this module's graph, we create Node for this module
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
elif isinstance(submodule_obj, (LayerChoice, InputChoice)):
subcell.update_label(sub_m_attrs['label'])
else:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
node_index[node] = subcell
# connect the cell into graph
_add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
else: else:
raise RuntimeError('unsupported CallMethod {}'.format(node.s('name'))) src_node_idx = predecessor_node_outputs.index(_output)
elif node.kind() == 'prim::CallFunction': ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
func_type_str = _remove_mangle(node.inputsAt(0).type().str()) tail=(ir_graph.output_node, None))
func = node.inputsAt(0).node()
assert func.kind() == 'prim::Constant'
assert func.hasAttribute('name')
func_name = func.s('name')
# create node for func
global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, global_seq),
'{}.{}'.format(func_type_str, func_name))
node_index[node] = func_node
_add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant':
new_node = create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct':
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, global_seq), node.kind())
node_index[node] = new_node
_add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append':
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = handle_prim_attr_node(node)
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# last_block_node is None means no node in the branch block
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise RuntimeError('Loop has not been supported yet!')
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
return node_index[node]
for node in sm_graph.nodes():
handle_single_node(node)
return node_index
def merge_aten_slices(ir_graph):
"""
if there is aten::slice node, merge the consecutive ones together.
```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
"""
head_slice_nodes = []
has_slice_node = False
for node in ir_graph.hidden_nodes:
if node.operation.type == 'aten::slice':
has_slice_node = True
for pred in node.predecessors:
if pred.operation.type not in ['aten::slice', 'prim::Constant']:
head_slice_nodes.append(node)
break
if has_slice_node:
assert head_slice_nodes
for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
if len(head_node.incoming_edges) == 4:
# when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
slot += 5
node = head_node
while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
suc_node = node.successors[0]
assert len(suc_node.incoming_edges) == 5
for edge in suc_node.incoming_edges:
if edge.tail_slot == 0:
edge.remove()
else:
edge.tail = new_slice_node
edge.tail_slot = slot + edge.tail_slot - 1
slot += 4
ir_graph.hidden_nodes.remove(node)
node = suc_node
for edge in node.outgoing_edges: self.refine_graph(ir_graph)
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node)
ir_graph._register()
def refine_graph(ir_graph): return ir_graph, {}
"""
Do the following process to simplify graph:
1. remove unconnected constant node
2. remove unconnected getattr node
"""
# some constant is not used, for example, function name as prim::Constant
remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
merge_aten_slices(ir_graph)
def _handle_layerchoice(module):
global modules_arg
m_attrs = {}
candidates = module.candidate_ops
choices = []
for cand in candidates:
assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
assert isinstance(modules_arg[id(cand)], dict)
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
choices.append({'type': cand_type, 'parameters': modules_arg[id(cand)]})
m_attrs[f'choices'] = choices
m_attrs['label'] = module.label
return m_attrs
def _handle_inputchoice(module):
m_attrs = {}
m_attrs['n_candidates'] = module.n_candidates
m_attrs['n_chosen'] = module.n_chosen
m_attrs['reduction'] = module.reduction
m_attrs['label'] = module.label
return m_attrs
def convert_module(script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
Returns def convert_to_graph(script_module, module):
-------
Graph
the built graph ir from module, ```None``` means do not further parse the module
dict
the input arguments of this module
"""
global global_graph_id
global modules_arg
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
original_type_name = script_module.original_name
if original_type_name == OpTypeName.LayerChoice:
m_attrs = _handle_layerchoice(module)
return None, m_attrs
if original_type_name == OpTypeName.InputChoice:
m_attrs = _handle_inputchoice(module)
return None, m_attrs
if original_type_name == OpTypeName.Placeholder:
m_attrs = modules_arg[id(module)]
return None, m_attrs
if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
assert id(module) in modules_arg, f'{original_type_name} arguments are not recorded'
m_attrs = modules_arg[id(module)]
return None, m_attrs
# handle TorchScript graph
sm_graph = script_module.graph
global_graph_id += 1
ir_graph = Graph(model=ir_model, graph_id=global_graph_id, name=module_name, _internal=True)
# handle graph nodes
node_index = handle_graph_nodes(script_module, sm_graph, module,
module_name, ir_model, ir_graph)
# handle graph outputs
for _output in sm_graph.outputs():
ir_graph._add_output(_convert_name(_output.debugName()))
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
src_node_idx = None
else:
src_node_idx = predecessor_node_outputs.index(_output)
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
tail=(ir_graph.output_node, None))
refine_graph(ir_graph)
ir_graph._register()
if id(module) not in modules_arg:
raise RuntimeError(f'{original_type_name} arguments are not recorded, \
you might have forgotten to decorate this class with @register_module()')
# TODO: if we parse this module, it means we will create a graph (module class)
# for this module. Then it is not necessary to record this module's arguments
# return ir_graph, modules_arg[id(module)].
# That is, we can refactor this part, to allow users to annotate which module
# should not be parsed further.
return ir_graph, {}
def convert_to_graph(script_module, module, recorded_modules_arg):
""" """
Convert module to our graph ir, i.e., build a ```Model``` type Convert module to our graph ir, i.e., build a ```Model``` type
@@ -527,18 +510,15 @@ def convert_to_graph(script_module, module, recorded_modules_arg):
         the script module obtained with torch.jit.script
     module : nn.Module
         the targeted module instance
-    recorded_modules_arg : dict
-        the recorded args of each module in the module

     Returns
     -------
     Model
         the constructed IR model
     """
-    global modules_arg
-    modules_arg = recorded_modules_arg
     model = Model(_internal=True)
     module_name = '_model'
-    convert_module(script_module, module, module_name, model)
+    GraphConverter().convert_module(script_module, module, module_name, model)
     return model
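A rough sketch of how the refactored converter is driven end to end; the import path of `convert_to_graph` and the small `Net` class are assumptions for illustration:

    import torch
    import nni.retiarii.nn.pytorch as nn
    from nni.retiarii.converter.graph_gen import convert_to_graph   # import path assumed

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 4)

        def forward(self, x):
            return self.fc(x)

    net = Net()
    script_module = torch.jit.script(net)            # model must be scriptable, see the error handling below
    model_ir = convert_to_graph(script_module, net)  # builds and returns the graph IR `Model`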
@@ -29,6 +29,7 @@ _logger = logging.getLogger(__name__)
 OneShotTrainers = (DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer)


 @dataclass(init=False)
 class RetiariiExeConfig(ConfigBase):
     experiment_name: Optional[str] = None
@@ -125,14 +126,17 @@ class RetiariiExperiment(Experiment):
         except Exception as e:
             _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
             raise e
-        base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
-        assert id(self.trainer) in self.recorded_module_args
-        trainer_config = self.recorded_module_args[id(self.trainer)]
-        base_model.apply_trainer(trainer_config['modulename'], trainer_config['args'])
+        base_model_ir = convert_to_graph(script_module, self.base_model)
+        recorded_module_args = get_records()
+        if id(self.trainer) not in recorded_module_args:
+            raise KeyError('Your trainer is not found in registered classes. You might have forgotten to \
+                register your customized trainer with @register_trainer decorator.')
+        trainer_config = recorded_module_args[id(self.trainer)]
+        base_model_ir.apply_trainer(trainer_config['modulename'], trainer_config['args'])

         # handle inline mutations
-        mutators = self._process_inline_mutation(base_model)
+        mutators = self._process_inline_mutation(base_model_ir)
         if mutators is not None and self.applied_mutators:
             raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \
                 do not use mutators when you use LayerChoice/InputChoice')
@@ -140,7 +144,7 @@ class RetiariiExperiment(Experiment):
             self.applied_mutators = mutators

         _logger.info('Starting strategy...')
-        Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start()
+        Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)).start()
         _logger.info('Strategy started!')

     def start(self, port: int = 8080, debug: bool = False) -> None:
...
-import inspect
 import logging
 from typing import Any, List

 import torch
 import torch.nn as nn

-from ...utils import add_record, version_larger_equal
+from ...utils import add_record, blackbox_module, uid, version_larger_equal

 _logger = logging.getLogger(__name__)
@@ -40,16 +39,13 @@ if version_larger_equal(torch.__version__, '1.6.0'):
 if version_larger_equal(torch.__version__, '1.7.0'):
     __all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])
-    #'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
-    #'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
-    #'ChannelShuffle'


 class LayerChoice(nn.Module):
     def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
         super(LayerChoice, self).__init__()
-        self.candidate_ops = op_candidates
-        self.label = key
-        self.key = key  # deprecated, for backward compatibility
+        self.op_candidates = op_candidates
+        self.label = key if key is not None else f'layerchoice_{uid()}'
+        self.key = self.label  # deprecated, for backward compatibility
         for i, module in enumerate(op_candidates):  # deprecated, for backward compatibility
             self.add_module(str(i), module)
         if reduction or return_mask:
@@ -66,8 +62,8 @@ class InputChoice(nn.Module):
         self.n_candidates = n_candidates
         self.n_chosen = n_chosen
         self.reduction = reduction
-        self.label = key
-        self.key = key  # deprecated, for backward compatibility
+        self.label = key if key is not None else f'inputchoice_{uid()}'
+        self.key = self.label  # deprecated, for backward compatibility
         if choose_from or return_mask:
             _logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')
@@ -101,6 +97,7 @@ class Placeholder(nn.Module):

 class ChosenInputs(nn.Module):
     """
     """

     def __init__(self, chosen: List[int], reduction: str):
         super().__init__()
         self.chosen = chosen
@@ -128,9 +125,7 @@ class ChosenInputs(nn.Module):

 # the following are pytorch modules
-class Module(nn.Module):
-    def __init__(self):
-        super(Module, self).__init__()
+Module = nn.Module


 class Sequential(nn.Sequential):
@@ -145,143 +140,116 @@ class ModuleList(nn.ModuleList):
         super(ModuleList, self).__init__(*args)
def wrap_module(original_class): Identity = blackbox_module(nn.Identity)
orig_init = original_class.__init__ Linear = blackbox_module(nn.Linear)
argname_list = list(inspect.signature(original_class).parameters.keys()) Conv1d = blackbox_module(nn.Conv1d)
# Make copy of original __init__, so we can call it without recursion Conv2d = blackbox_module(nn.Conv2d)
Conv3d = blackbox_module(nn.Conv3d)
def __init__(self, *args, **kws): ConvTranspose1d = blackbox_module(nn.ConvTranspose1d)
full_args = {} ConvTranspose2d = blackbox_module(nn.ConvTranspose2d)
full_args.update(kws) ConvTranspose3d = blackbox_module(nn.ConvTranspose3d)
for i, arg in enumerate(args): Threshold = blackbox_module(nn.Threshold)
full_args[argname_list[i]] = arg ReLU = blackbox_module(nn.ReLU)
add_record(id(self), full_args) Hardtanh = blackbox_module(nn.Hardtanh)
ReLU6 = blackbox_module(nn.ReLU6)
orig_init(self, *args, **kws) # Call the original __init__ Sigmoid = blackbox_module(nn.Sigmoid)
Tanh = blackbox_module(nn.Tanh)
original_class.__init__ = __init__ # Set the class' __init__ to the new one Softmax = blackbox_module(nn.Softmax)
return original_class Softmax2d = blackbox_module(nn.Softmax2d)
LogSoftmax = blackbox_module(nn.LogSoftmax)
ELU = blackbox_module(nn.ELU)
Identity = wrap_module(nn.Identity) SELU = blackbox_module(nn.SELU)
Linear = wrap_module(nn.Linear) CELU = blackbox_module(nn.CELU)
Conv1d = wrap_module(nn.Conv1d) GLU = blackbox_module(nn.GLU)
Conv2d = wrap_module(nn.Conv2d) GELU = blackbox_module(nn.GELU)
Conv3d = wrap_module(nn.Conv3d) Hardshrink = blackbox_module(nn.Hardshrink)
ConvTranspose1d = wrap_module(nn.ConvTranspose1d) LeakyReLU = blackbox_module(nn.LeakyReLU)
ConvTranspose2d = wrap_module(nn.ConvTranspose2d) LogSigmoid = blackbox_module(nn.LogSigmoid)
ConvTranspose3d = wrap_module(nn.ConvTranspose3d) Softplus = blackbox_module(nn.Softplus)
Threshold = wrap_module(nn.Threshold) Softshrink = blackbox_module(nn.Softshrink)
ReLU = wrap_module(nn.ReLU) MultiheadAttention = blackbox_module(nn.MultiheadAttention)
Hardtanh = wrap_module(nn.Hardtanh) PReLU = blackbox_module(nn.PReLU)
ReLU6 = wrap_module(nn.ReLU6) Softsign = blackbox_module(nn.Softsign)
Sigmoid = wrap_module(nn.Sigmoid) Softmin = blackbox_module(nn.Softmin)
Tanh = wrap_module(nn.Tanh) Tanhshrink = blackbox_module(nn.Tanhshrink)
Softmax = wrap_module(nn.Softmax) RReLU = blackbox_module(nn.RReLU)
Softmax2d = wrap_module(nn.Softmax2d) AvgPool1d = blackbox_module(nn.AvgPool1d)
LogSoftmax = wrap_module(nn.LogSoftmax) AvgPool2d = blackbox_module(nn.AvgPool2d)
ELU = wrap_module(nn.ELU) AvgPool3d = blackbox_module(nn.AvgPool3d)
SELU = wrap_module(nn.SELU) MaxPool1d = blackbox_module(nn.MaxPool1d)
CELU = wrap_module(nn.CELU) MaxPool2d = blackbox_module(nn.MaxPool2d)
GLU = wrap_module(nn.GLU) MaxPool3d = blackbox_module(nn.MaxPool3d)
GELU = wrap_module(nn.GELU) MaxUnpool1d = blackbox_module(nn.MaxUnpool1d)
Hardshrink = wrap_module(nn.Hardshrink) MaxUnpool2d = blackbox_module(nn.MaxUnpool2d)
LeakyReLU = wrap_module(nn.LeakyReLU) MaxUnpool3d = blackbox_module(nn.MaxUnpool3d)
LogSigmoid = wrap_module(nn.LogSigmoid) FractionalMaxPool2d = blackbox_module(nn.FractionalMaxPool2d)
Softplus = wrap_module(nn.Softplus) FractionalMaxPool3d = blackbox_module(nn.FractionalMaxPool3d)
Softshrink = wrap_module(nn.Softshrink) LPPool1d = blackbox_module(nn.LPPool1d)
MultiheadAttention = wrap_module(nn.MultiheadAttention) LPPool2d = blackbox_module(nn.LPPool2d)
PReLU = wrap_module(nn.PReLU) LocalResponseNorm = blackbox_module(nn.LocalResponseNorm)
Softsign = wrap_module(nn.Softsign) BatchNorm1d = blackbox_module(nn.BatchNorm1d)
Softmin = wrap_module(nn.Softmin) BatchNorm2d = blackbox_module(nn.BatchNorm2d)
Tanhshrink = wrap_module(nn.Tanhshrink) BatchNorm3d = blackbox_module(nn.BatchNorm3d)
RReLU = wrap_module(nn.RReLU) InstanceNorm1d = blackbox_module(nn.InstanceNorm1d)
AvgPool1d = wrap_module(nn.AvgPool1d) InstanceNorm2d = blackbox_module(nn.InstanceNorm2d)
AvgPool2d = wrap_module(nn.AvgPool2d) InstanceNorm3d = blackbox_module(nn.InstanceNorm3d)
AvgPool3d = wrap_module(nn.AvgPool3d) LayerNorm = blackbox_module(nn.LayerNorm)
MaxPool1d = wrap_module(nn.MaxPool1d) GroupNorm = blackbox_module(nn.GroupNorm)
MaxPool2d = wrap_module(nn.MaxPool2d) SyncBatchNorm = blackbox_module(nn.SyncBatchNorm)
MaxPool3d = wrap_module(nn.MaxPool3d) Dropout = blackbox_module(nn.Dropout)
MaxUnpool1d = wrap_module(nn.MaxUnpool1d) Dropout2d = blackbox_module(nn.Dropout2d)
MaxUnpool2d = wrap_module(nn.MaxUnpool2d) Dropout3d = blackbox_module(nn.Dropout3d)
MaxUnpool3d = wrap_module(nn.MaxUnpool3d) AlphaDropout = blackbox_module(nn.AlphaDropout)
FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d) FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout)
FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d) ReflectionPad1d = blackbox_module(nn.ReflectionPad1d)
LPPool1d = wrap_module(nn.LPPool1d) ReflectionPad2d = blackbox_module(nn.ReflectionPad2d)
LPPool2d = wrap_module(nn.LPPool2d) ReplicationPad2d = blackbox_module(nn.ReplicationPad2d)
LocalResponseNorm = wrap_module(nn.LocalResponseNorm) ReplicationPad1d = blackbox_module(nn.ReplicationPad1d)
BatchNorm1d = wrap_module(nn.BatchNorm1d) ReplicationPad3d = blackbox_module(nn.ReplicationPad3d)
BatchNorm2d = wrap_module(nn.BatchNorm2d) CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d)
BatchNorm3d = wrap_module(nn.BatchNorm3d) Embedding = blackbox_module(nn.Embedding)
InstanceNorm1d = wrap_module(nn.InstanceNorm1d) EmbeddingBag = blackbox_module(nn.EmbeddingBag)
InstanceNorm2d = wrap_module(nn.InstanceNorm2d) RNNBase = blackbox_module(nn.RNNBase)
InstanceNorm3d = wrap_module(nn.InstanceNorm3d) RNN = blackbox_module(nn.RNN)
LayerNorm = wrap_module(nn.LayerNorm) LSTM = blackbox_module(nn.LSTM)
GroupNorm = wrap_module(nn.GroupNorm) GRU = blackbox_module(nn.GRU)
SyncBatchNorm = wrap_module(nn.SyncBatchNorm) RNNCellBase = blackbox_module(nn.RNNCellBase)
Dropout = wrap_module(nn.Dropout) RNNCell = blackbox_module(nn.RNNCell)
Dropout2d = wrap_module(nn.Dropout2d) LSTMCell = blackbox_module(nn.LSTMCell)
Dropout3d = wrap_module(nn.Dropout3d) GRUCell = blackbox_module(nn.GRUCell)
AlphaDropout = wrap_module(nn.AlphaDropout) PixelShuffle = blackbox_module(nn.PixelShuffle)
FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout) Upsample = blackbox_module(nn.Upsample)
ReflectionPad1d = wrap_module(nn.ReflectionPad1d) UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d)
ReflectionPad2d = wrap_module(nn.ReflectionPad2d) UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d)
ReplicationPad2d = wrap_module(nn.ReplicationPad2d) PairwiseDistance = blackbox_module(nn.PairwiseDistance)
ReplicationPad1d = wrap_module(nn.ReplicationPad1d) AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d)
ReplicationPad3d = wrap_module(nn.ReplicationPad3d) AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d)
CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d) AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d)
Embedding = wrap_module(nn.Embedding) AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d)
EmbeddingBag = wrap_module(nn.EmbeddingBag) AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d)
RNNBase = wrap_module(nn.RNNBase) AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d)
RNN = wrap_module(nn.RNN) TripletMarginLoss = blackbox_module(nn.TripletMarginLoss)
LSTM = wrap_module(nn.LSTM) ZeroPad2d = blackbox_module(nn.ZeroPad2d)
GRU = wrap_module(nn.GRU) ConstantPad1d = blackbox_module(nn.ConstantPad1d)
RNNCellBase = wrap_module(nn.RNNCellBase) ConstantPad2d = blackbox_module(nn.ConstantPad2d)
RNNCell = wrap_module(nn.RNNCell) ConstantPad3d = blackbox_module(nn.ConstantPad3d)
LSTMCell = wrap_module(nn.LSTMCell) Bilinear = blackbox_module(nn.Bilinear)
GRUCell = wrap_module(nn.GRUCell) CosineSimilarity = blackbox_module(nn.CosineSimilarity)
PixelShuffle = wrap_module(nn.PixelShuffle) Unfold = blackbox_module(nn.Unfold)
Upsample = wrap_module(nn.Upsample) Fold = blackbox_module(nn.Fold)
UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d) AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d) TransformerEncoder = blackbox_module(nn.TransformerEncoder)
PairwiseDistance = wrap_module(nn.PairwiseDistance) TransformerDecoder = blackbox_module(nn.TransformerDecoder)
AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d) TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d) TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d) Transformer = blackbox_module(nn.Transformer)
AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d) Flatten = blackbox_module(nn.Flatten)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d) Hardsigmoid = blackbox_module(nn.Hardsigmoid)
AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
ZeroPad2d = wrap_module(nn.ZeroPad2d)
ConstantPad1d = wrap_module(nn.ConstantPad1d)
ConstantPad2d = wrap_module(nn.ConstantPad2d)
ConstantPad3d = wrap_module(nn.ConstantPad3d)
Bilinear = wrap_module(nn.Bilinear)
CosineSimilarity = wrap_module(nn.CosineSimilarity)
Unfold = wrap_module(nn.Unfold)
Fold = wrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = wrap_module(nn.TransformerEncoder)
TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer)
Flatten = wrap_module(nn.Flatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
 if version_larger_equal(torch.__version__, '1.6.0'):
-    Hardswish = wrap_module(nn.Hardswish)
+    Hardswish = blackbox_module(nn.Hardswish)
 if version_larger_equal(torch.__version__, '1.7.0'):
-    SiLU = wrap_module(nn.SiLU)
-    Unflatten = wrap_module(nn.Unflatten)
-    TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
+    SiLU = blackbox_module(nn.SiLU)
+    Unflatten = blackbox_module(nn.Unflatten)
+    TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss)

-#LazyLinear = wrap_module(nn.LazyLinear)
-#LazyConv1d = wrap_module(nn.LazyConv1d)
-#LazyConv2d = wrap_module(nn.LazyConv2d)
-#LazyConv3d = wrap_module(nn.LazyConv3d)
-#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
-#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
-#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
-#ChannelShuffle = wrap_module(nn.ChannelShuffle)
\ No newline at end of file
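With the `uid()`-based defaults above, the label no longer stays `None` when `key` is omitted. A rough sketch (assuming `LayerChoice` is importable from `nni.retiarii.nn.pytorch`; the exact numeric suffix depends on the uid counter):

    import nni.retiarii.nn.pytorch as nn

    choice = nn.LayerChoice([
        nn.Conv2d(16, 16, 3, padding=1),
        nn.MaxPool2d(3, stride=1, padding=1),
    ])
    print(choice.label)   # e.g. 'layerchoice_1'; choice.key is now an alias of label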
@@ -43,7 +43,7 @@ def get_default_transform(dataset: str) -> Any:
     return None


-@register_trainer()
+@register_trainer
 class PyTorchImageClassificationTrainer(BaseTrainer):
     """
     Image classification trainer for PyTorch.
@@ -80,7 +80,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
             Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
             only the key ``max_epochs`` is useful.
         """
-        super(PyTorchImageClassificationTrainer, self).__init__()
+        super().__init__()
         self._use_cuda = torch.cuda.is_available()
         self.model = model
         if self._use_cuda:
...
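Custom trainers follow the same pattern as the built-in one above; a minimal sketch, with `MyTrainer`, its arguments and the `BaseTrainer` import path being illustrative assumptions:

    from nni.retiarii import register_trainer
    from nni.retiarii.trainer import BaseTrainer    # import path assumed

    @register_trainer                               # note: no longer called as @register_trainer()
    class MyTrainer(BaseTrainer):
        def __init__(self, model, max_epochs=10):
            super().__init__()
            self.model = model
            self.max_epochs = max_epochs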
 import inspect
+import warnings
 from collections import defaultdict
 from typing import Any
@@ -10,12 +11,14 @@ def import_(target: str, allow_none: bool = False) -> Any:
     module = __import__(path, globals(), locals(), [identifier])
     return getattr(module, identifier)


 def version_larger_equal(a: str, b: str) -> bool:
     # TODO: refactor later
     a = a.split('+')[0]
     b = b.split('+')[0]
     return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))


 _records = {}
@@ -29,73 +32,87 @@ def add_record(key, value):
     """
     global _records
     if _records is not None:
-        #assert key not in _records, '{} already in _records'.format(key)
+        assert key not in _records, '{} already in _records'.format(key)
         _records[key] = value
-def _register_module(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            full_args[argname_list[i]] = arg
-        add_record(id(self), full_args)
-
-        orig_init(self, *args, **kws)  # Call the original __init__
-
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-
-def register_module():
-    """
-    Register a module.
-    """
-    # use it as a decorator: @register_module()
-    def _register(cls):
-        m = _register_module(
-            original_class=cls)
-        return m
-
-    return _register
-
-
-def _register_trainer(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-
-    full_class_name = original_class.__module__ + '.' + original_class.__name__
-
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            # TODO: support both pytorch and tensorflow
-            from .nn.pytorch import Module
-            if isinstance(args[i], Module):
-                # ignore the base model object
-                continue
-            full_args[argname_list[i]] = arg
-        add_record(id(self), {'modulename': full_class_name, 'args': full_args})
-
-        orig_init(self, *args, **kws)  # Call the original __init__
-
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-
-def register_trainer():
-    def _register(cls):
-        m = _register_trainer(
-            original_class=cls)
-        return m
-
-    return _register
+def del_record(key):
+    global _records
+    if _records is not None:
+        _records.pop(key, None)
+
+
+def _blackbox_cls(cls, module_name, register_format=None):
+    class wrapper(cls):
+        def __init__(self, *args, **kwargs):
+            argname_list = list(inspect.signature(cls).parameters.keys())
+            full_args = {}
+            full_args.update(kwargs)
+
+            assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
+            for argname, value in zip(argname_list, args):
+                full_args[argname] = value
+
+            # eject un-serializable arguments
+            for k in list(full_args.keys()):
+                # The list is not complete and does not support nested cases.
+                if not isinstance(full_args[k], (int, float, str, dict, list)):
+                    if not (register_format == 'full' and k == 'model'):
+                        # no warning if it is base model in trainer
+                        warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
+                            This is not supported. You can ignore this warning if you are passing the model to trainer.')
+                    full_args.pop(k)
+
+            if register_format == 'args':
+                add_record(id(self), full_args)
+            elif register_format == 'full':
+                full_class_name = cls.__module__ + '.' + cls.__name__
+                add_record(id(self), {'modulename': full_class_name, 'args': full_args})
+
+            super().__init__(*args, **kwargs)
+
+        def __del__(self):
+            del_record(id(self))
+
+    # using module_name instead of cls.__module__ because it's more natural to see where the module gets wrapped
+    # instead of simply putting torch.nn or etc.
+    wrapper.__module__ = module_name
+    wrapper.__name__ = cls.__name__
+    wrapper.__qualname__ = cls.__qualname__
+    wrapper.__init__.__doc__ = cls.__init__.__doc__
+
+    return wrapper
+
+
+def blackbox(cls, *args, **kwargs):
+    """
+    To create an blackbox instance inline without decorator. For example,
+
+    .. code-block:: python
+
+        self.op = blackbox(MyCustomOp, hidden_units=128)
+    """
+    # get caller module name
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')(*args, **kwargs)
+
+
+def blackbox_module(cls):
+    """
+    Register a module. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')
+
+
+def register_trainer(cls):
+    """
+    Register a trainer. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'full')
_last_uid = defaultdict(int) _last_uid = defaultdict(int)
......
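In the new helpers, `_blackbox_cls` returns a subclass of the wrapped class whose `__init__` maps positional arguments onto parameter names, drops un-serializable values, records the result keyed by `id(self)`, and only then calls the original constructor; `__del__` removes the record again. `blackbox_module` and `register_trainer` differ only in the record format (`'args'` vs `'full'`). A minimal usage sketch, where `MyOp`, `PlainOp` and the `hidden_units` parameter are invented for illustration:

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox, blackbox_module

@blackbox_module                 # decorator form: wrap the class once
class MyOp(nn.Module):
    def __init__(self, hidden_units):
        super().__init__()
        self.fc = nn.Linear(hidden_units, hidden_units)

op = MyOp(128)                   # recorded as {'hidden_units': 128}

class PlainOp(nn.Module):
    def __init__(self, hidden_units):
        super().__init__()
        self.fc = nn.Linear(hidden_units, hidden_units)

op2 = blackbox(PlainOp, hidden_units=128)   # inline form: wrap and instantiate in one call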
...@@ -5,6 +5,7 @@ tuner_result.txt ...@@ -5,6 +5,7 @@ tuner_result.txt
assessor_result.txt assessor_result.txt
_generated_model.py _generated_model.py
_generated_model_*.py
data data
generated generated
...@@ -7,9 +7,9 @@ import torch.nn as torch_nn ...@@ -7,9 +7,9 @@ import torch.nn as torch_nn
import ops import ops
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module from nni.retiarii import blackbox_module
@blackbox_module
class AuxiliaryHead(nn.Module): class AuxiliaryHead(nn.Module):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """ """ Auxiliary head in 2/3 place of network to let the gradient flow well """
...@@ -35,7 +35,6 @@ class AuxiliaryHead(nn.Module): ...@@ -35,7 +35,6 @@ class AuxiliaryHead(nn.Module):
logits = self.linear(out) logits = self.linear(out)
return logits return logits
@register_module()
class Node(nn.Module): class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect): def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__() super().__init__()
...@@ -66,7 +65,6 @@ class Node(nn.Module): ...@@ -66,7 +65,6 @@ class Node(nn.Module):
#out = [self.drop_path(o) if o is not None else None for o in out] #out = [self.drop_path(o) if o is not None else None for o in out]
return self.input_switch(out) return self.input_switch(out)
@register_module()
class Cell(nn.Module): class Cell(nn.Module):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction): def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
...@@ -100,7 +98,6 @@ class Cell(nn.Module): ...@@ -100,7 +98,6 @@ class Cell(nn.Module):
output = torch.cat(new_tensors, dim=1) output = torch.cat(new_tensors, dim=1)
return output return output
@register_module()
class CNN(nn.Module): class CNN(nn.Module):
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4, def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
......
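In the DARTS example above, only `AuxiliaryHead` keeps its decorator (now `@blackbox_module`); the `@register_module()` decorators on `Node`, `Cell` and `CNN` are removed, so those classes no longer record their constructor arguments. The base model is still constructed the same way, as in the search script further below:

base_model = CNN(32, 3, 16, 10, 8)   # input_size, in_channels, channels, n_classes, n_layers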
import torch import torch
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module from nni.retiarii import blackbox_module
@register_module() @blackbox_module
class DropPath(nn.Module): class DropPath(nn.Module):
def __init__(self, p=0.): def __init__(self, p=0.):
""" """
...@@ -12,7 +12,7 @@ class DropPath(nn.Module): ...@@ -12,7 +12,7 @@ class DropPath(nn.Module):
p : float p : float
Probability of an path to be zeroed. Probability of an path to be zeroed.
""" """
super(DropPath, self).__init__() super().__init__()
self.p = p self.p = p
def forward(self, x): def forward(self, x):
...@@ -24,13 +24,13 @@ class DropPath(nn.Module): ...@@ -24,13 +24,13 @@ class DropPath(nn.Module):
return x return x
@register_module() @blackbox_module
class PoolBN(nn.Module): class PoolBN(nn.Module):
""" """
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`. AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
""" """
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True): def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
super(PoolBN, self).__init__() super().__init__()
if pool_type.lower() == 'max': if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding) self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg': elif pool_type.lower() == 'avg':
...@@ -45,13 +45,13 @@ class PoolBN(nn.Module): ...@@ -45,13 +45,13 @@ class PoolBN(nn.Module):
out = self.bn(out) out = self.bn(out)
return out return out
@register_module() @blackbox_module
class StdConv(nn.Module): class StdConv(nn.Module):
""" """
Standard conv: ReLU - Conv - BN Standard conv: ReLU - Conv - BN
""" """
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(StdConv, self).__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.ReLU(), nn.ReLU(),
nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False), nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
...@@ -61,13 +61,13 @@ class StdConv(nn.Module): ...@@ -61,13 +61,13 @@ class StdConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@register_module() @blackbox_module
class FacConv(nn.Module): class FacConv(nn.Module):
""" """
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
""" """
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True): def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super(FacConv, self).__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.ReLU(), nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False), nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
...@@ -78,7 +78,7 @@ class FacConv(nn.Module): ...@@ -78,7 +78,7 @@ class FacConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@register_module() @blackbox_module
class DilConv(nn.Module): class DilConv(nn.Module):
""" """
(Dilated) depthwise separable conv. (Dilated) depthwise separable conv.
...@@ -86,7 +86,7 @@ class DilConv(nn.Module): ...@@ -86,7 +86,7 @@ class DilConv(nn.Module):
If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field. If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
""" """
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super(DilConv, self).__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.ReLU(), nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
...@@ -98,14 +98,14 @@ class DilConv(nn.Module): ...@@ -98,14 +98,14 @@ class DilConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@register_module() @blackbox_module
class SepConv(nn.Module): class SepConv(nn.Module):
""" """
Depthwise separable conv. Depthwise separable conv.
DilConv(dilation=1) * 2. DilConv(dilation=1) * 2.
""" """
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine), DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine) DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
...@@ -114,13 +114,13 @@ class SepConv(nn.Module): ...@@ -114,13 +114,13 @@ class SepConv(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
@register_module() @blackbox_module
class FactorizedReduce(nn.Module): class FactorizedReduce(nn.Module):
""" """
Reduce feature map size by factorized pointwise (stride=2). Reduce feature map size by factorized pointwise (stride=2).
""" """
def __init__(self, C_in, C_out, affine=True): def __init__(self, C_in, C_out, affine=True):
super(FactorizedReduce, self).__init__() super().__init__()
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
......
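Each candidate op in this file is now wrapped with `@blackbox_module`, so its constructor arguments are captured per instance while it still behaves as an ordinary module. A hedged sketch of composing two of them inside a `LayerChoice` (the channel and kernel numbers are arbitrary):

import nni.retiarii.nn.pytorch as nn
from ops import PoolBN, SepConv   # the ops module shown above

choice = nn.LayerChoice([
    PoolBN('max', 16, 3, 1, 1),   # pool_type, C, kernel_size, stride, padding
    SepConv(16, 16, 3, 1, 1),     # C_in, C_out, kernel_size, stride, padding
])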
...@@ -13,10 +13,10 @@ from darts_model import CNN ...@@ -13,10 +13,10 @@ from darts_model import CNN
if __name__ == '__main__': if __name__ == '__main__':
base_model = CNN(32, 3, 16, 10, 8) base_model = CNN(32, 3, 16, 10, 8)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10", trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
dataset_kwargs={"root": "data/cifar10", "download": True}, dataset_kwargs={"root": "data/cifar10", "download": True},
dataloader_kwargs={"batch_size": 32}, dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3}, optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1}) trainer_kwargs={"max_epochs": 1})
#simple_strategy = TPEStrategy() #simple_strategy = TPEStrategy()
simple_strategy = RandomStrategy() simple_strategy = RandomStrategy()
...@@ -31,4 +31,4 @@ if __name__ == '__main__': ...@@ -31,4 +31,4 @@ if __name__ == '__main__':
exp_config.training_service.use_active_gpu = True exp_config.training_service.use_active_gpu = True
exp_config.training_service.gpu_indices = [1, 2] exp_config.training_service.gpu_indices = [1, 2]
exp.run(exp_config, 8081, debug=True) exp.run(exp_config, 8081)
...@@ -56,8 +56,8 @@ def get_dataset(cls, cutout_length=0): ...@@ -56,8 +56,8 @@ def get_dataset(cls, cutout_length=0):
valid_transform = transforms.Compose(normalize) valid_transform = transforms.Compose(normalize)
if cls == "cifar10": if cls == "cifar10":
dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform) dataset_train = CIFAR10(root="./data/cifar10", train=True, download=True, transform=train_transform)
dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform) dataset_valid = CIFAR10(root="./data/cifar10", train=False, download=True, transform=valid_transform)
else: else:
raise NotImplementedError raise NotImplementedError
return dataset_train, dataset_valid return dataset_train, dataset_valid
......
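The CIFAR-10 root moves from `./data` to `./data/cifar10`, which keeps it consistent with the `dataset_kwargs={"root": "data/cifar10", ...}` used by the trainers in these examples. The helper is called as before:

dataset_train, dataset_valid = get_dataset("cifar10")   # now downloads into ./data/cifar10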
from nni.retiarii import blackbox_module
import nni.retiarii.nn.pytorch as nn
import warnings import warnings
import torch import torch
...@@ -8,8 +10,6 @@ import torch.nn.functional as F ...@@ -8,8 +10,6 @@ import torch.nn.functional as F
import sys import sys
from pathlib import Path from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[2])) sys.path.append(str(Path(__file__).resolve().parents[2]))
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module
# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow. # 1.0 - tensorflow.
...@@ -27,6 +27,7 @@ class _ResidualBlock(nn.Module): ...@@ -27,6 +27,7 @@ class _ResidualBlock(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) + x return self.net(x) + x
class _InvertedResidual(nn.Module): class _InvertedResidual(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, skip, bn_momentum=0.1): def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, skip, bn_momentum=0.1):
...@@ -110,7 +111,7 @@ def _get_depths(depths, alpha): ...@@ -110,7 +111,7 @@ def _get_depths(depths, alpha):
rather than down. """ rather than down. """
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
@register_module()
class MNASNet(nn.Module): class MNASNet(nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model. implements the B1 variant of the model.
...@@ -127,7 +128,7 @@ class MNASNet(nn.Module): ...@@ -127,7 +128,7 @@ class MNASNet(nn.Module):
def __init__(self, alpha, depths, convops, kernel_sizes, num_layers, def __init__(self, alpha, depths, convops, kernel_sizes, num_layers,
skips, num_classes=1000, dropout=0.2): skips, num_classes=1000, dropout=0.2):
super(MNASNet, self).__init__() super().__init__()
assert alpha > 0.0 assert alpha > 0.0
assert len(depths) == len(convops) == len(kernel_sizes) == len(num_layers) == len(skips) == 7 assert len(depths) == len(convops) == len(kernel_sizes) == len(num_layers) == len(skips) == 7
self.alpha = alpha self.alpha = alpha
...@@ -143,22 +144,22 @@ class MNASNet(nn.Module): ...@@ -143,22 +144,22 @@ class MNASNet(nn.Module):
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
] ]
count = 0 count = 0
#for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \ # for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \
# zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios): # zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios):
for filter_size, exp_ratio, stride in zip(base_filter_sizes, exp_ratios, strides): for filter_size, exp_ratio, stride in zip(base_filter_sizes, exp_ratios, strides):
# TODO: restrict that "choose" can only be used within mutator # TODO: restrict that "choose" can only be used within mutator
ph = nn.Placeholder(label=f'mutable_{count}', related_info={ ph = nn.Placeholder(label=f'mutable_{count}', related_info={
'kernel_size_options': [1, 3, 5], 'kernel_size_options': [1, 3, 5],
'n_layer_options': [1, 2, 3, 4], 'n_layer_options': [1, 2, 3, 4],
'op_type_options': ['__mutated__.base_mnasnet.RegularConv', 'op_type_options': ['__mutated__.base_mnasnet.RegularConv',
'__mutated__.base_mnasnet.DepthwiseConv', '__mutated__.base_mnasnet.DepthwiseConv',
'__mutated__.base_mnasnet.MobileConv'], '__mutated__.base_mnasnet.MobileConv'],
#'se_ratio_options': [0, 0.25], # 'se_ratio_options': [0, 0.25],
'skip_options': ['identity', 'no'], 'skip_options': ['identity', 'no'],
'n_filter_options': [int(filter_size*x) for x in [0.75, 1.0, 1.25]], 'n_filter_options': [int(filter_size*x) for x in [0.75, 1.0, 1.25]],
'exp_ratio': exp_ratio, 'exp_ratio': exp_ratio,
'stride': stride, 'stride': stride,
'in_ch': depths[0] if count == 0 else None 'in_ch': depths[0] if count == 0 else None
}) })
layers.append(ph) layers.append(ph)
'''if conv == "mconv": '''if conv == "mconv":
...@@ -185,7 +186,7 @@ class MNASNet(nn.Module): ...@@ -185,7 +186,7 @@ class MNASNet(nn.Module):
#self.for_test = 10 #self.for_test = 10
def forward(self, x): def forward(self, x):
#if self.for_test == 10: # if self.for_test == 10:
x = self.layers(x) x = self.layers(x)
# Equivalent to global avgpool and removing H and W dimensions. # Equivalent to global avgpool and removing H and W dimensions.
x = x.mean([2, 3]) x = x.mean([2, 3])
...@@ -196,7 +197,7 @@ class MNASNet(nn.Module): ...@@ -196,7 +197,7 @@ class MNASNet(nn.Module):
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
torch_nn.init.kaiming_normal_(m.weight, mode="fan_out", torch_nn.init.kaiming_normal_(m.weight, mode="fan_out",
nonlinearity="relu") nonlinearity="relu")
if m.bias is not None: if m.bias is not None:
torch_nn.init.zeros_(m.bias) torch_nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
...@@ -204,16 +205,18 @@ class MNASNet(nn.Module): ...@@ -204,16 +205,18 @@ class MNASNet(nn.Module):
torch_nn.init.zeros_(m.bias) torch_nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear): elif isinstance(m, nn.Linear):
torch_nn.init.kaiming_uniform_(m.weight, mode="fan_out", torch_nn.init.kaiming_uniform_(m.weight, mode="fan_out",
nonlinearity="sigmoid") nonlinearity="sigmoid")
torch_nn.init.zeros_(m.bias) torch_nn.init.zeros_(m.bias)
def test_model(model): def test_model(model):
model(torch.randn(2, 3, 224, 224)) model(torch.randn(2, 3, 224, 224))
#====================definition of candidate op classes
# ====================definition of candidate op classes
BN_MOMENTUM = 1 - 0.9997 BN_MOMENTUM = 1 - 0.9997
class RegularConv(nn.Module): class RegularConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride): def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__() super().__init__()
...@@ -234,6 +237,7 @@ class RegularConv(nn.Module): ...@@ -234,6 +237,7 @@ class RegularConv(nn.Module):
out = out + x out = out + x
return out return out
class DepthwiseConv(nn.Module): class DepthwiseConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride): def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__() super().__init__()
...@@ -257,6 +261,7 @@ class DepthwiseConv(nn.Module): ...@@ -257,6 +261,7 @@ class DepthwiseConv(nn.Module):
out = out + x out = out + x
return out return out
class MobileConv(nn.Module): class MobileConv(nn.Module):
def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride): def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
super().__init__() super().__init__()
...@@ -274,7 +279,7 @@ class MobileConv(nn.Module): ...@@ -274,7 +279,7 @@ class MobileConv(nn.Module):
nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM), nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
# Depthwise # Depthwise
nn.Conv2d(mid_ch, mid_ch, kernel_size, padding= (kernel_size - 1) // 2, nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=(kernel_size - 1) // 2,
stride=stride, groups=mid_ch, bias=False), stride=stride, groups=mid_ch, bias=False),
nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM), nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
...@@ -288,5 +293,6 @@ class MobileConv(nn.Module): ...@@ -288,5 +293,6 @@ class MobileConv(nn.Module):
out = out + x out = out + x
return out return out
# mnasnet0_5 # mnasnet0_5
ir_module = _InvertedResidual(16, 16, 3, 1, 1, True) ir_module = _InvertedResidual(16, 16, 3, 1, 1, True)
\ No newline at end of file
...@@ -19,12 +19,12 @@ if __name__ == '__main__': ...@@ -19,12 +19,12 @@ if __name__ == '__main__':
_DEFAULT_NUM_LAYERS = [1, 3, 3, 3, 2, 4, 1] _DEFAULT_NUM_LAYERS = [1, 3, 3, 3, 2, 4, 1]
base_model = MNASNet(0.5, _DEFAULT_DEPTHS, _DEFAULT_CONVOPS, _DEFAULT_KERNEL_SIZES, base_model = MNASNet(0.5, _DEFAULT_DEPTHS, _DEFAULT_CONVOPS, _DEFAULT_KERNEL_SIZES,
_DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS) _DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10", trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
dataset_kwargs={"root": "data/cifar10", "download": True}, dataset_kwargs={"root": "data/cifar10", "download": True},
dataloader_kwargs={"batch_size": 32}, dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3}, optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1}) trainer_kwargs={"max_epochs": 1})
# new interface # new interface
applied_mutators = [] applied_mutators = []
...@@ -41,4 +41,4 @@ if __name__ == '__main__': ...@@ -41,4 +41,4 @@ if __name__ == '__main__':
exp_config.max_trial_number = 10 exp_config.max_trial_number = 10
exp_config.training_service.use_active_gpu = False exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081, debug=True) exp.run(exp_config, 8081)
import random
import nni.retiarii.nn.pytorch as nn
import torch.nn.functional as F
from nni.retiarii.experiment import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.strategies import RandomStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer
class Net(nn.Module):
def __init__(self, hidden_size):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.LayerChoice([
nn.Linear(4*4*50, hidden_size),
nn.Linear(4*4*50, hidden_size, bias=False)
])
self.fc2 = nn.Linear(hidden_size, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
if __name__ == '__main__':
base_model = Net(128)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="MNIST",
dataset_kwargs={"root": "data/mnist", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
simple_strategy = RandomStrategy()
exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_search'
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 10
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081 + random.randint(0, 100))
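This new end-to-end example searches over the two `LayerChoice` candidates for `fc1` on MNIST. Assuming the script is saved as something like `mnist_search.py` alongside the other Retiarii examples (the file name is a guess), it is launched the same way:

# python mnist_search.py            # file name assumed
# the experiment then serves on the random port passed to exp.run above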