Unverified commit 4784cc6c authored by liuzhe-lz, committed by GitHub

Merge pull request #3302 from microsoft/v2.0-merge

Merge branch v2.0 into master (no squash)
parents 25db55ca 349ead41
@@ -147,7 +147,7 @@ class Mutator(BaseMutator):
         Parameters
         ----------
-        mutable : LayerChoice
+        mutable : nni.nas.pytorch.mutables.LayerChoice
            Layer choice module.
        args : list of torch.Tensor
            Inputs
@@ -180,7 +180,7 @@ class Mutator(BaseMutator):
        Parameters
        ----------
-        mutable : InputChoice
+        mutable : nni.nas.pytorch.mutables.InputChoice
            Input choice module.
        tensor_list : list of torch.Tensor
            Tensor list to apply the decision on.
......
@@ -2,4 +2,4 @@ from .operation import Operation
 from .graph import *
 from .execution import *
 from .mutator import *
-from .utils import register_module
+from .utils import blackbox, blackbox_module, register_trainer
\ No newline at end of file
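The re-export change above replaces `register_module` with `blackbox`, `blackbox_module` and `register_trainer`. A minimal, hedged sketch of how the new helpers might be applied, assuming (as the rest of this PR suggests) that they record constructor arguments so the converter can treat the class as an opaque node or a trainer; the classes below are illustrative, not part of NNI:

```python
# Hypothetical usage sketch of the new registration helpers; the exact decorator
# semantics are assumed from this PR's diff, not from separate documentation.
import torch.nn as nn
from nni.retiarii import blackbox_module, register_trainer

@blackbox_module          # assumed: keep this module opaque, do not parse its TorchScript graph
class MyConvBlock(nn.Module):
    def __init__(self, channels: int = 16):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

@register_trainer         # assumed: record the trainer class and its kwargs for trial execution
class MyTrainer:
    def __init__(self, model, learning_rate: float = 1e-3):
        self.model = model
        self.learning_rate = learning_rate

    def fit(self):
        ...  # training loop omitted in this sketch
```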
@@ -19,10 +19,10 @@ def model_to_pytorch_script(model: Model, placement=None) -> str:
 def _sorted_incoming_edges(node: Node) -> List[Edge]:
     edges = [edge for edge in node.graph.edges if edge.tail is node]
-    _logger.info('sorted_incoming_edges: %s', str(edges))
+    _logger.debug('sorted_incoming_edges: %s', str(edges))
     if not edges:
         return []
-    _logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
+    _logger.debug('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
     if all(edge.tail_slot is None for edge in edges):
         return edges
     if all(isinstance(edge.tail_slot, int) for edge in edges):
......
...@@ -6,517 +6,501 @@ import torch ...@@ -6,517 +6,501 @@ import torch
from ..graph import Graph, Model, Node from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell from ..operation import Cell
from ..utils import get_records
from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
from .utils import _convert_name, build_full_name from .utils import _convert_name, build_full_name
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
global_seq = 0
global_graph_id = 0
modules_arg = None
class GraphConverter:
def __init__(self):
self.global_seq = 0
self.global_graph_id = 0
self.modules_arg = get_records()
def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False): def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
""" """
Parameters Parameters
---------- ----------
ir_graph : Graph ir_graph : Graph
node : torch._C.Node node : torch._C.Node
graph_inputs : List[torch._C.Value] graph_inputs : List[torch._C.Value]
a list of a script graph's inputs a list of a script graph's inputs
node_index : Dict node_index : Dict
new_node : Node new_node : Node
newly created ir node corresponding to `node` newly created ir node corresponding to `node`
output_remap : Dict output_remap : Dict
ignore_first : bool ignore_first : bool
if it is true, skip the first input if it is true, skip the first input
""" """
is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1 is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
new_node_input_idx = 0 new_node_input_idx = 0
for _input in node.inputs(): for _input in node.inputs():
if ignore_first: if ignore_first:
ignore_first = False ignore_first = False
continue continue
# handle source node # handle source node
if _input in graph_inputs: if _input in graph_inputs:
idx = graph_inputs.index(_input) idx = graph_inputs.index(_input)
src_node = ir_graph.input_node src_node = ir_graph.input_node
src_node_idx = idx src_node_idx = idx
elif _input in output_remap: elif _input in output_remap:
assert output_remap[_input].kind() == 'aten::append' assert output_remap[_input].kind() == 'aten::append'
predecessor_node = output_remap[_input] predecessor_node = output_remap[_input]
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node) assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
src_node_idx = None src_node_idx = None
src_node = node_index[predecessor_node] src_node = node_index[predecessor_node]
assert isinstance(src_node, Node) assert isinstance(src_node, Node)
else:
predecessor_node = _input.node()
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
# find out the index of _input in the outputs of predecessor_node
predecessor_outputs = [_output for _output in predecessor_node.outputs()]
if len(predecessor_outputs) == 1:
idx = None
else: else:
idx = predecessor_outputs.index(_input) predecessor_node = _input.node()
ir_predecessor_node = node_index[predecessor_node] assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
src_node_idx = idx # find out the index of _input in the outputs of predecessor_node
assert isinstance(ir_predecessor_node, Node) predecessor_outputs = [_output for _output in predecessor_node.outputs()]
src_node = ir_predecessor_node if len(predecessor_outputs) == 1:
idx = None
# handle destination node else:
dst_node = new_node idx = predecessor_outputs.index(_input)
if is_single_input: ir_predecessor_node = node_index[predecessor_node]
dst_node_idx = None src_node_idx = idx
else: assert isinstance(ir_predecessor_node, Node)
dst_node_idx = new_node_input_idx src_node = ir_predecessor_node
# create edge # handle destination node
ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx)) dst_node = new_node
if is_single_input:
new_node_input_idx += 1 dst_node_idx = None
else:
dst_node_idx = new_node_input_idx
def create_prim_constant_node(ir_graph, node, module_name):
global global_seq
attrs = {}
if node.outputsAt(0).toIValue() is not None:
attrs = {'value': node.outputsAt(0).toIValue()}
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, global_seq),
node.kind(), attrs)
return new_node
def handle_prim_attr_node(node):
assert node.hasAttribute('name')
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
return node.kind(), attrs
def _remove_mangle(module_type_str):
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
def remove_unconnected_nodes(ir_graph, targeted_type=None): # create edge
""" ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))
Parameters
----------
ir_graph : Graph
our ir graph representation
targeted_type : str
nodes with ```targeted_type``` will be removed from graph if their fanout is 0.
```None``` means removing all the nodes whose fanout is 0.
"""
# build index of outputs of Node(s)
node_fanout = set()
for edge in ir_graph.edges:
if edge.head.id not in node_fanout:
node_fanout.add(edge.head.id)
to_removes = []
for hidden_node in ir_graph.hidden_nodes:
if hidden_node.id not in node_fanout:
assert isinstance(hidden_node, Node)
if targeted_type is None:
to_removes.append(hidden_node)
elif hidden_node.operation.type == targeted_type:
to_removes.append(hidden_node)
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
"""
Convert torch script node to our node ir, and build our graph ir
Parameters new_node_input_idx += 1
----------
script_module : torch.jit.RecursiveScriptModule
the torch script of ```module```
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
```module```'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ```module```
Returns def create_prim_constant_node(self, ir_graph, node, module_name):
------- attrs = {}
dict if node.outputsAt(0).toIValue() is not None:
the mapping from graph node to our graph ir node attrs = {'value': node.outputsAt(0).toIValue()}
""" self.global_seq += 1
# handle inputs new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
graph_inputs = [] node.kind(), attrs)
for _input in sm_graph.inputs(): return new_node
if _input.debugName() == 'self':
assert _input.unique() == 0
continue
graph_inputs.append(_input)
# TODO: add scope name
ir_graph._add_input(_convert_name(_input.debugName()))
node_index = {} # graph node to graph ir node
# some node does not have output but it modifies a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
# %out.1 is updated, and %17 is None
# we add output to this type of node and connect it to the following node which uses %out.1
# key: tensor (%out.1), value: node (this node)
output_remap = {}
def handle_if_condition(cond_tensor):
"""
to calculate the condition, we only deal with the following op types by tracing back
`prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`
generate the expression using recursive calls def handle_prim_attr_node(self, node):
assert node.hasAttribute('name')
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
return node.kind(), attrs
NOTE: do not support dynamic graph def _remove_mangle(self, module_type_str):
""" return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
def _generate_expr(tensor):
if tensor.node().kind() == 'prim::GetAttr':
return f'({getattr(module, tensor.node().s("name"))})'
elif tensor.node().kind() == 'aten::__getitem__':
t = _generate_expr(tensor.node().inputsAt(0))
idx = _generate_expr(tensor.node().inputsAt(1))
return f'({t}[{idx}])'
elif tensor.node().kind() == 'prim::Constant':
return f'{tensor.toIValue()}'
elif tensor.node().kind() == 'aten::eq':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})'
else:
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition')
expr = _generate_expr(cond_tensor)
return eval(expr)
def handle_if_node(node): def remove_unconnected_nodes(self, ir_graph, targeted_type=None):
""" """
Parameters Parameters
---------- ----------
node : torch._C.Node ir_graph : Graph
the node from TorchScript graph our ir graph representation
targeted_type : str
Returns nodes with ```targeted_type``` will be removed from graph if their fanout is 0.
------- ```None``` means removing all the nodes whose fanout is 0.
Node
the created node ir
""" """
# only deal with input of prim::If is constant or attribute for now # build index of outputs of Node(s)
# will support constant expression in future node_fanout = set()
inputs = [i for i in node.inputs()] for edge in ir_graph.edges:
assert len(inputs) == 1 if edge.head.id not in node_fanout:
cond = handle_if_condition(inputs[0]) node_fanout.add(edge.head.id)
chosen_block = 0 if cond else 1
blocks = [block for block in node.blocks()] to_removes = []
assert len(blocks) == 2 for hidden_node in ir_graph.hidden_nodes:
last_block_node = None if hidden_node.id not in node_fanout:
for node in blocks[chosen_block].nodes(): assert isinstance(hidden_node, Node)
last_block_node = handle_single_node(node) if targeted_type is None:
return last_block_node to_removes.append(hidden_node)
elif hidden_node.operation.type == targeted_type:
def handle_single_node(node): to_removes.append(hidden_node)
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(self, script_module, sm_graph, module, module_name, ir_model, ir_graph):
""" """
Convert torch script node to our node ir, and build our graph ir
Parameters Parameters
---------- ----------
node : torch._C.Node script_module : torch.jit.RecursiveScriptModule
the node from TorchScript graph the torch script of ```module```
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
```module```'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ```module```
Returns Returns
------- -------
Node dict
the created node ir the mapping from graph node to our graph ir node
""" """
global global_seq # handle inputs
if node.kind() == 'prim::CallMethod': graph_inputs = []
# get and handle the first input, which should be an nn.Module for _input in sm_graph.inputs():
assert node.hasAttribute('name') if _input.debugName() == 'self':
if node.s('name') == 'forward': assert _input.unique() == 0
# node.inputsAt(0).type() is <class 'torch._C.ClassType'> continue
submodule_type_str = _remove_mangle(node.inputsAt(0).type().str()) graph_inputs.append(_input)
submodule = node.inputsAt(0).node() # TODO: add scope name
assert submodule.kind() == 'prim::GetAttr' ir_graph._add_input(_convert_name(_input.debugName()))
assert submodule.hasAttribute('name')
submodule_name = submodule.s('name') node_index = {} # graph node to graph ir node
if submodule.inputsAt(0).debugName() == 'self': # some node does not have output but it modifies a variable, for example aten::append
# module is usually instantiated in __init__. # %17 : Tensor[] = aten::append(%out.1, %16)
# when calling a module in forward, # %out.1 is updated, and %17 is None
# prim::GetAttr is used to obtain the module in torch script. # we add output to this type of node and connect it to the following node which uses %out.1
# therefore, we do this check for a module. example below: # key: tensor (%out.1), value: node (this node)
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self) output_remap = {}
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format( def handle_if_condition(cond_tensor):
submodule_name, script_module._modules.keys()) """
to calculate the condition, we only deal with the following op types by tracing back
submodule_full_name = build_full_name(module_name, submodule_name) `prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name], generate the expression using recursive calls
submodule_obj,
submodule_full_name, ir_model) NOTE: do not support dynamic graph
"""
def _generate_expr(tensor):
if tensor.node().kind() == 'prim::GetAttr':
return f'({getattr(module, tensor.node().s("name"))})'
elif tensor.node().kind() == 'aten::__getitem__':
t = _generate_expr(tensor.node().inputsAt(0))
idx = _generate_expr(tensor.node().inputsAt(1))
return f'({t}[{idx}])'
elif tensor.node().kind() == 'prim::Constant':
return f'{tensor.toIValue()}'
elif tensor.node().kind() == 'aten::eq':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})'
else: else:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self) raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition')
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8) expr = _generate_expr(cond_tensor)
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4) return eval(expr)
if submodule.inputsAt(0).type().name() == 'ModuleList':
# handle ModuleList def handle_if_node(node):
predecessor = submodule.inputsAt(0).node() """
assert predecessor.kind() == 'prim::GetAttr' Parameters
assert predecessor.hasAttribute('name') ----------
assert predecessor.inputsAt(0).debugName() == 'self' node : torch._C.Node
predecessor_name = predecessor.s('name') the node from TorchScript graph
# FIXME: exchange
submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name]) Returns
predecessor_obj = getattr(module, predecessor_name) -------
submodule_obj = getattr(predecessor_obj, submodule_name) Node
subgraph, sub_m_attrs = convert_module(script_module._modules[predecessor_name]._modules[submodule_name], the created node ir
submodule_obj, submodule_full_name, ir_model) """
# only deal with input of prim::If is constant or attribute for now
# will support constant expression in future
inputs = [i for i in node.inputs()]
assert len(inputs) == 1
cond = handle_if_condition(inputs[0])
chosen_block = 0 if cond else 1
blocks = [block for block in node.blocks()]
assert len(blocks) == 2
last_block_node = None
for node in blocks[chosen_block].nodes():
last_block_node = handle_single_node(node)
return last_block_node
def handle_single_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
if node.kind() == 'prim::CallMethod':
# get and handle the first input, which should be an nn.Module
assert node.hasAttribute('name')
if node.s('name') == 'forward':
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
submodule = node.inputsAt(0).node()
assert submodule.kind() == 'prim::GetAttr'
assert submodule.hasAttribute('name')
submodule_name = submodule.s('name')
if submodule.inputsAt(0).debugName() == 'self':
# module is usually instantiated in __init__.
# when calling a module in forward,
# prim::GetAttr is used to obtain the module in torch script.
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
submodule_name, script_module._modules.keys())
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = self.convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, ir_model)
else: else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str())) # %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
# TODO: match subgraph with maintained graphs # %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
# build cell if submodule.inputsAt(0).type().name() == 'ModuleList':
if subgraph is None: # handle ModuleList
# if we do not parse this module's graph, we create Node for this module predecessor = submodule.inputsAt(0).node()
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs) assert predecessor.kind() == 'prim::GetAttr'
if isinstance(submodule_obj, Placeholder): assert predecessor.hasAttribute('name')
subcell.update_label(submodule_obj.label) assert predecessor.inputsAt(0).debugName() == 'self'
elif isinstance(submodule_obj, (LayerChoice, InputChoice)): predecessor_name = predecessor.s('name')
subcell.update_label(sub_m_attrs['label']) # FIXME: exchange
submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
predecessor_obj = getattr(module, predecessor_name)
submodule_obj = getattr(predecessor_obj, submodule_name)
subgraph, sub_m_attrs = self.convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
submodule_obj, submodule_full_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
# TODO: match subgraph with maintained graphs
# build cell
if subgraph is None:
# if we do not parse this module's graph, we create Node for this module
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
elif isinstance(submodule_obj, (LayerChoice, InputChoice)):
subcell.update_label(sub_m_attrs['label'])
else:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
node_index[node] = subcell
# connect the cell into graph
self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
else: else:
# Graph already created, create Cell for it raise RuntimeError('unsupported CallMethod {}'.format(node.s('name')))
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs) elif node.kind() == 'prim::CallFunction':
subcell = ir_graph.add_node(submodule_full_name, new_cell) func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
node_index[node] = subcell func = node.inputsAt(0).node()
# connect the cell into graph assert func.kind() == 'prim::Constant'
_add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True) assert func.hasAttribute('name')
func_name = func.s('name')
# create node for func
self.global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
'{}.{}'.format(func_type_str, func_name))
node_index[node] = func_node
self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant':
new_node = self.create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct':
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, self.global_seq), node.kind())
node_index[node] = new_node
self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append':
self.global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
self.global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = self.handle_prim_attr_node(node)
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# last_block_node is None means no node in the branch block
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise RuntimeError('Loop has not been supported yet!')
else: else:
raise RuntimeError('unsupported CallMethod {}'.format(node.s('name'))) raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
elif node.kind() == 'prim::CallFunction':
func_type_str = _remove_mangle(node.inputsAt(0).type().str())
func = node.inputsAt(0).node()
assert func.kind() == 'prim::Constant'
assert func.hasAttribute('name')
func_name = func.s('name')
# create node for func
global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, global_seq),
'{}.{}'.format(func_type_str, func_name))
node_index[node] = func_node
_add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant':
new_node = create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct':
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, global_seq), node.kind())
node_index[node] = new_node
_add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append':
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = handle_prim_attr_node(node)
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# last_block_node is None means no node in the branch block
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise RuntimeError('Loop has not been supported yet!')
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
return node_index[node]
for node in sm_graph.nodes():
handle_single_node(node)
return node_index
def merge_aten_slices(ir_graph):
"""
if there is aten::slice node, merge the consecutive ones together.
```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
"""
head_slice_nodes = []
has_slice_node = False
for node in ir_graph.hidden_nodes:
if node.operation.type == 'aten::slice':
has_slice_node = True
for pred in node.predecessors:
if pred.operation.type not in ['aten::slice', 'prim::Constant']:
head_slice_nodes.append(node)
break
if has_slice_node:
assert head_slice_nodes
for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
if len(head_node.incoming_edges) == 4:
# when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
slot += 5
node = head_node
while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
suc_node = node.successors[0]
assert len(suc_node.incoming_edges) == 5
for edge in suc_node.incoming_edges:
if edge.tail_slot == 0:
edge.remove()
else:
edge.tail = new_slice_node
edge.tail_slot = slot + edge.tail_slot - 1
slot += 4
ir_graph.hidden_nodes.remove(node)
node = suc_node
for edge in node.outgoing_edges: return node_index[node]
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node)
for node in sm_graph.nodes():
handle_single_node(node)
def refine_graph(ir_graph): return node_index
"""
Do the following process to simplify graph:
1. remove unconnected constant node
2. remove unconnected getattr node
"""
# some constant is not used, for example, function name as prim::Constant
remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
merge_aten_slices(ir_graph)
def merge_aten_slices(self, ir_graph):
"""
if there is aten::slice node, merge the consecutive ones together.
```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
"""
head_slice_nodes = []
has_slice_node = False
for node in ir_graph.hidden_nodes:
if node.operation.type == 'aten::slice':
has_slice_node = True
for pred in node.predecessors:
if pred.operation.type not in ['aten::slice', 'prim::Constant']:
head_slice_nodes.append(node)
break
if has_slice_node:
assert head_slice_nodes
for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
if len(head_node.incoming_edges) == 4:
# when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
slot += 5
node = head_node
while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
suc_node = node.successors[0]
assert len(suc_node.incoming_edges) == 5
for edge in suc_node.incoming_edges:
if edge.tail_slot == 0:
edge.remove()
else:
edge.tail = new_slice_node
edge.tail_slot = slot + edge.tail_slot - 1
slot += 4
ir_graph.hidden_nodes.remove(node)
node = suc_node
for edge in node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node)
def refine_graph(self, ir_graph):
"""
Do the following process to simplify graph:
1. remove unconnected constant node
2. remove unconnected getattr node
"""
# some constant is not used, for example, function name as prim::Constant
self.remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
self.merge_aten_slices(ir_graph)
def _handle_layerchoice(self, module):
m_attrs = {}
candidates = module.op_candidates
choices = []
for cand in candidates:
assert id(cand) in self.modules_arg, 'id not exist: {}'.format(id(cand))
assert isinstance(self.modules_arg[id(cand)], dict)
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
choices.append({'type': cand_type, 'parameters': self.modules_arg[id(cand)]})
m_attrs[f'choices'] = choices
m_attrs['label'] = module.label
return m_attrs
def _handle_inputchoice(self, module):
m_attrs = {}
m_attrs['n_candidates'] = module.n_candidates
m_attrs['n_chosen'] = module.n_chosen
m_attrs['reduction'] = module.reduction
m_attrs['label'] = module.label
return m_attrs
def convert_module(self, script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
def _handle_layerchoice(module): Parameters
global modules_arg ----------
script_module : torch.jit.RecursiveScriptModule
the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
m_attrs = {} Returns
candidates = module.candidate_ops -------
choices = [] Graph
for cand in candidates: the built graph ir from module, ```None``` means do not further parse the module
assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand)) dict
assert isinstance(modules_arg[id(cand)], dict) the input arguments of this module
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__ """
choices.append({'type': cand_type, 'parameters': modules_arg[id(cand)]})
m_attrs[f'choices'] = choices
m_attrs['label'] = module.label
return m_attrs
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
original_type_name = script_module.original_name
m_attrs = None
if original_type_name in MODULE_EXCEPT_LIST:
pass # do nothing
elif original_type_name == OpTypeName.LayerChoice:
m_attrs = self._handle_layerchoice(module)
elif original_type_name == OpTypeName.InputChoice:
m_attrs = self._handle_inputchoice(module)
elif original_type_name == OpTypeName.Placeholder:
m_attrs = self.modules_arg[id(module)]
elif original_type_name in torch.nn.__dict__:
# this is a basic module from pytorch, no need to parse its graph
assert id(module) in self.modules_arg, f'{original_type_name} arguments are not recorded'
m_attrs = self.modules_arg[id(module)]
elif id(module) in self.modules_arg:
# this module is marked as blackbox, won't continue to parse
m_attrs = self.modules_arg[id(module)]
if m_attrs is not None:
return None, m_attrs
# handle TorchScript graph
sm_graph = script_module.graph
self.global_graph_id += 1
ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)
# handle graph nodes
node_index = self.handle_graph_nodes(script_module, sm_graph, module,
module_name, ir_model, ir_graph)
# handle graph outputs
for _output in sm_graph.outputs():
ir_graph._add_output(_convert_name(_output.debugName()))
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
src_node_idx = None
else:
src_node_idx = predecessor_node_outputs.index(_output)
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
tail=(ir_graph.output_node, None))
def _handle_inputchoice(module): self.refine_graph(ir_graph)
m_attrs = {}
m_attrs['n_chosen'] = module.n_chosen
m_attrs['reduction'] = module.reduction
m_attrs['label'] = module.label
return m_attrs
ir_graph._register()
def convert_module(script_module, module, module_name, ir_model): return ir_graph, {}
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
Returns def convert_to_graph(script_module, module):
-------
Graph
the built graph ir from module, ```None``` means do not further parse the module
dict
the input arguments of this module
"""
global global_graph_id
global modules_arg
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
original_type_name = script_module.original_name
if original_type_name == OpTypeName.LayerChoice:
m_attrs = _handle_layerchoice(module)
return None, m_attrs
if original_type_name == OpTypeName.InputChoice:
m_attrs = _handle_inputchoice(module)
return None, m_attrs
if original_type_name == OpTypeName.Placeholder:
m_attrs = modules_arg[id(module)]
return None, m_attrs
if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
assert id(module) in modules_arg, f'{original_type_name} arguments are not recorded'
m_attrs = modules_arg[id(module)]
return None, m_attrs
# handle TorchScript graph
sm_graph = script_module.graph
global_graph_id += 1
ir_graph = Graph(model=ir_model, graph_id=global_graph_id, name=module_name, _internal=True)
# handle graph nodes
node_index = handle_graph_nodes(script_module, sm_graph, module,
module_name, ir_model, ir_graph)
# handle graph outputs
for _output in sm_graph.outputs():
ir_graph._add_output(_convert_name(_output.debugName()))
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
src_node_idx = None
else:
src_node_idx = predecessor_node_outputs.index(_output)
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
tail=(ir_graph.output_node, None))
refine_graph(ir_graph)
ir_graph._register()
if id(module) not in modules_arg:
raise RuntimeError(f'{original_type_name} arguments are not recorded, \
you might have forgotten to decorate this class with @register_module()')
# TODO: if we parse this module, it means we will create a graph (module class)
# for this module. Then it is not necessary to record this module's arguments
# return ir_graph, modules_arg[id(module)].
# That is, we can refactor this part, to allow users to annotate which module
# should not be parsed further.
return ir_graph, {}
def convert_to_graph(script_module, module, recorded_modules_arg):
""" """
Convert module to our graph ir, i.e., build a ```Model``` type Convert module to our graph ir, i.e., build a ```Model``` type
...@@ -526,18 +510,15 @@ def convert_to_graph(script_module, module, recorded_modules_arg): ...@@ -526,18 +510,15 @@ def convert_to_graph(script_module, module, recorded_modules_arg):
the script module obtained with torch.jit.script the script module obtained with torch.jit.script
module : nn.Module module : nn.Module
the targeted module instance the targeted module instance
recorded_modules_arg : dict
the recorded args of each module in the module
Returns Returns
-------
Model Model
the constructed IR model the constructed IR model
""" """
global modules_arg
modules_arg = recorded_modules_arg
model = Model(_internal=True) model = Model(_internal=True)
module_name = '_model' module_name = '_model'
convert_module(script_module, module, module_name, model) GraphConverter().convert_module(script_module, module, module_name, model)
return model return model
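Taken together, the converter refactoring above moves the module-level globals into a `GraphConverter` instance, and `convert_to_graph` no longer takes the recorded module arguments explicitly (it reads them via `get_records()`). A short sketch of the intended call sequence, inferred from the call sites in this diff; the `Net` model here is a trivial stand-in:

```python
# Illustrative only: how the refactored converter is driven after this PR
# (torch.jit.script followed by convert_to_graph); Net is a placeholder model.
import torch
import torch.nn as nn
from nni.retiarii.converter import convert_to_graph

class Net(nn.Module):            # trivial stand-in model for illustration
    def forward(self, x):
        return torch.relu(x)

net = Net()
script_module = torch.jit.script(net)            # TorchScript graph used as the parsing source
model_ir = convert_to_graph(script_module, net)  # builds the graph IR (Model) via GraphConverter
```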
@@ -30,6 +30,10 @@ BasicOpsPT = {
     'aten::size': 'Size',
     'aten::view': 'View',
     'aten::eq': 'Eq',
+    'aten::Bool': 'Bool',
+    'aten::empty': 'Empty',
+    'aten::zeros': 'Zeros',
+    'aten::chunk': 'Chunk',
     'aten::add_': 'Add_'  # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
 }
......
 import time
-import os
-from typing import List

 from ..graph import Model, ModelStatus
-from .base import BaseExecutionEngine
-from .cgo_engine import CGOExecutionEngine
-from .interface import AbstractExecutionEngine, WorkerInfo
+from .interface import AbstractExecutionEngine
 from .listener import DefaultListener

 _execution_engine = None
 _default_listener = None

 __all__ = ['get_execution_engine', 'get_and_register_default_listener',
-           'submit_models', 'wait_models', 'query_available_resources']
+           'submit_models', 'wait_models', 'query_available_resources',
+           'set_execution_engine', 'is_stopped_exec']


+def set_execution_engine(engine) -> None:
+    global _execution_engine
+    if _execution_engine is None:
+        _execution_engine = engine
+    else:
+        raise RuntimeError('execution engine is already set')


-def get_execution_engine() -> BaseExecutionEngine:
+def get_execution_engine() -> AbstractExecutionEngine:
     """
     Currently we assume the default execution engine is BaseExecutionEngine.
     """
     global _execution_engine
-    if _execution_engine is None:
-        if os.environ.get('CGO') == 'true':
-            _execution_engine = CGOExecutionEngine()
-        else:
-            _execution_engine = BaseExecutionEngine()
     return _execution_engine
@@ -51,6 +49,11 @@ def wait_models(*models: Model) -> None:
             break


-def query_available_resources() -> List[WorkerInfo]:
-    listener = get_and_register_default_listener(get_execution_engine())
-    return listener.resources
+def query_available_resources() -> int:
+    engine = get_execution_engine()
+    resources = engine.query_available_resource()
+    return resources if isinstance(resources, int) else len(resources)
+
+
+def is_stopped_exec(model: Model) -> bool:
+    return model.status in (ModelStatus.Trained, ModelStatus.Failed)
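With this change the engine is installed once via `set_execution_engine` (done by the advisor later in this PR), and `query_available_resources` normalizes the engine's answer to an integer count. A hedged sketch of how a strategy-side caller might use the new API; the polling loop and function name are illustrative:

```python
# Illustrative only: shows the normalized resource query and is_stopped_exec helper
# introduced in this PR. A real strategy would sleep or yield instead of busy-waiting.
from nni.retiarii.execution.api import (
    query_available_resources, submit_models, wait_models, is_stopped_exec)

def run_batch(models):
    while query_available_resources() <= 0:
        pass                       # placeholder: wait for the engine to report a free slot
    submit_models(*models)         # hand the mutated models to the execution engine
    wait_models(*models)           # block until they reach a terminal status
    return [m for m in models if is_stopped_exec(m)]
```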
 import logging
+import os
+import random
+import string

 from typing import Dict, Any, List

-from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
+from .interface import AbstractExecutionEngine, AbstractGraphListener
 from .. import codegen, utils
 from ..graph import Model, ModelStatus, MetricData
-from ..integration import send_trial, receive_trial_parameters, get_advisor
+from ..integration_api import send_trial, receive_trial_parameters, get_advisor

 _logger = logging.getLogger(__name__)
@@ -29,7 +32,7 @@ class BaseGraphData:
 class BaseExecutionEngine(AbstractExecutionEngine):
     """
     The execution engine with no optimization at all.
-    Resource management is yet to be implemented.
+    Resource management is implemented in this class.
     """

     def __init__(self) -> None:
@@ -50,6 +53,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         self._running_models: Dict[int, Model] = dict()

+        self.resources = 0

     def submit_models(self, *models: Model) -> None:
         for model in models:
             data = BaseGraphData(codegen.model_to_pytorch_script(model),
@@ -60,17 +65,14 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         self._listeners.append(listener)

     def _send_trial_callback(self, paramater: dict) -> None:
-        for listener in self._listeners:
-            _logger.warning('resources: %s', listener.resources)
-            if not listener.has_available_resource():
-                _logger.warning('There is no available resource, but trial is submitted.')
-            listener.on_resource_used(1)
-            _logger.warning('on_resource_used: %s', listener.resources)
+        if self.resources <= 0:
+            _logger.warning('There is no available resource, but trial is submitted.')
+        self.resources -= 1
+        _logger.info('on_resource_used: %d', self.resources)

     def _request_trial_jobs_callback(self, num_trials: int) -> None:
-        for listener in self._listeners:
-            listener.on_resource_available(1 * num_trials)
-            _logger.warning('on_resource_available: %s', listener.resources)
+        self.resources += num_trials
+        _logger.info('on_resource_available: %d', self.resources)

     def _trial_end_callback(self, trial_id: int, success: bool) -> None:
         model = self._running_models[trial_id]
@@ -93,8 +95,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         for listener in self._listeners:
             listener.on_metric(model, metrics)

-    def query_available_resource(self) -> List[WorkerInfo]:
-        raise NotImplementedError  # move the method from listener to here?
+    def query_available_resource(self) -> int:
+        return self.resources

     @classmethod
     def trial_execute_graph(cls) -> None:
@@ -102,9 +104,12 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         Initialize the model, hand it over to trainer.
         """
         graph_data = BaseGraphData.load(receive_trial_parameters())
-        with open('_generated_model.py', 'w') as f:
+        random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
+        file_name = f'_generated_model_{random_str}.py'
+        with open(file_name, 'w') as f:
             f.write(graph_data.model_script)
         trainer_cls = utils.import_(graph_data.training_module)
-        model_cls = utils.import_('_generated_model._model')
+        model_cls = utils.import_(f'_generated_model_{random_str}._model')
         trainer_instance = trainer_cls(model=model_cls(), **graph_data.training_kwargs)
         trainer_instance.fit()
+        os.remove(file_name)
\ No newline at end of file
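The trial side now writes the generated model to a uniquely named module before importing it, so concurrent trials on the same machine do not clobber each other's `_generated_model.py`. A simplified, self-contained sketch of that write, import, run pattern; the function and variable names here are placeholders rather than NNI internals:

```python
# Simplified illustration of the write -> import -> fit flow used by
# trial_execute_graph above; names are placeholders, not NNI's actual API.
import importlib
import os
import random
import string

def run_generated_model(model_script: str) -> None:
    suffix = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
    module_name = f'_generated_model_{suffix}'
    file_name = f'{module_name}.py'
    with open(file_name, 'w') as f:
        f.write(model_script)          # code produced by codegen.model_to_pytorch_script
    try:
        module = importlib.import_module(module_name)
        model = module._model()        # generated top-level model class is named `_model`
        ...                            # hand `model` to the trainer and call fit()
    finally:
        os.remove(file_name)           # clean up, mirroring the os.remove added in this PR
```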
@@ -4,7 +4,7 @@ from typing import List, Dict, Tuple
 from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
 from .. import codegen, utils
 from ..graph import Model, ModelStatus, MetricData
-from ..integration import send_trial, receive_trial_parameters, get_advisor
+from ..integration_api import send_trial, receive_trial_parameters, get_advisor
 from .logical_optimizer.logical_plan import LogicalPlan, PhysicalDevice
 from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
......
 from abc import ABC, abstractmethod, abstractclassmethod
-from typing import Any, NewType, List
+from typing import Any, NewType, List, Union

 from ..graph import Model, MetricData
@@ -59,13 +59,6 @@ class AbstractGraphListener(ABC):
         """
         pass

-    @abstractmethod
-    def on_resource_available(self, resources: List[WorkerInfo]) -> None:
-        """
-        Reports when a worker becomes idle.
-        """
-        pass


 class AbstractExecutionEngine(ABC):
     """
@@ -109,7 +102,7 @@ class AbstractExecutionEngine(ABC):
         raise NotImplementedError

     @abstractmethod
-    def query_available_resource(self) -> List[WorkerInfo]:
+    def query_available_resource(self) -> Union[List[WorkerInfo], int]:
         """
         Returns information of all idle workers.
         If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers.
......
@@ -3,11 +3,6 @@ from .interface import MetricData, AbstractGraphListener

 class DefaultListener(AbstractGraphListener):

-    def __init__(self):
-        self.resources: int = 0  # simply resource count
-
-    def has_available_resource(self) -> bool:
-        return self.resources > 0
-
     def on_metric(self, model: Model, metric: MetricData) -> None:
         model.metric = metric
@@ -20,9 +15,3 @@ class DefaultListener(AbstractGraphListener):
             model.status = ModelStatus.Trained
         else:
             model.status = ModelStatus.Failed
-
-    def on_resource_available(self, resources: int) -> None:
-        self.resources += resources
-
-    def on_resource_used(self, resources: int) -> None:
-        self.resources -= resources
 import logging
-import time
 from dataclasses import dataclass
 from pathlib import Path
@@ -7,20 +6,24 @@ from subprocess import Popen
 from threading import Thread
 from typing import Any, Optional

-from ..experiment import Experiment, TrainingServiceConfig, launcher, rest
+from ..experiment import Experiment, TrainingServiceConfig
 from ..experiment.config.base import ConfigBase, PathLike
 from ..experiment.config import util
 from ..experiment.pipe import Pipe
 from .graph import Model
 from .utils import get_records
 from .integration import RetiariiAdvisor
 from .converter import convert_to_graph
 from .mutator import Mutator, LayerChoiceMutator, InputChoiceMutator
-from .trainer.interface import BaseTrainer
+from .trainer.interface import BaseTrainer, BaseOneShotTrainer
 from .strategies.strategy import BaseStrategy
+from .trainer.pytorch import DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer

 _logger = logging.getLogger(__name__)

+OneShotTrainers = (DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer)


 @dataclass(init=False)
 class RetiariiExeConfig(ConfigBase):
@@ -43,7 +46,7 @@ class RetiariiExeConfig(ConfigBase):
         super().__init__(**kwargs)
         if training_service_platform is not None:
             assert 'training_service' not in kwargs
-            self.training_service = util.training_service_config_factory(training_service_platform)
+            self.training_service = util.training_service_config_factory(platform = training_service_platform)

     def validate(self, initialized_tuner: bool = False) -> None:
         super().validate()
@@ -76,7 +79,7 @@ _validation_rules = {
 class RetiariiExperiment(Experiment):
     def __init__(self, base_model: Model, trainer: BaseTrainer,
-                 applied_mutators: Mutator, strategy: BaseStrategy):
+                 applied_mutators: Mutator = None, strategy: BaseStrategy = None):
         self.config: RetiariiExeConfig = None
         self.port: Optional[int] = None
@@ -87,6 +90,7 @@ class RetiariiExperiment(Experiment):
         self.recorded_module_args = get_records()

         self._dispatcher = RetiariiAdvisor()
+        self._dispatcher_thread: Optional[Thread] = None
         self._proc: Optional[Popen] = None
         self._pipe: Optional[Pipe] = None
@@ -103,7 +107,10 @@ class RetiariiExperiment(Experiment):
                 mutator = LayerChoiceMutator(node.name, node.operation.parameters['choices'])
                 applied_mutators.append(mutator)
             for node in ic_nodes:
-                mutator = InputChoiceMutator(node.name, node.operation.parameters['n_chosen'])
+                mutator = InputChoiceMutator(node.name,
+                                             node.operation.parameters['n_candidates'],
+                                             node.operation.parameters['n_chosen'],
+                                             node.operation.parameters['reduction'])
                 applied_mutators.append(mutator)
             return applied_mutators
@@ -114,14 +121,17 @@ class RetiariiExperiment(Experiment):
         except Exception as e:
             _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
             raise e
-        base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
-        assert id(self.trainer) in self.recorded_module_args
-        trainer_config = self.recorded_module_args[id(self.trainer)]
-        base_model.apply_trainer(trainer_config['modulename'], trainer_config['args'])
+        base_model_ir = convert_to_graph(script_module, self.base_model)
+        recorded_module_args = get_records()
+        if id(self.trainer) not in recorded_module_args:
+            raise KeyError('Your trainer is not found in registered classes. You might have forgotten to \
+                register your customized trainer with @register_trainer decorator.')
+        trainer_config = recorded_module_args[id(self.trainer)]
+        base_model_ir.apply_trainer(trainer_config['modulename'], trainer_config['args'])

         # handle inline mutations
-        mutators = self._process_inline_mutation(base_model)
+        mutators = self._process_inline_mutation(base_model_ir)
         if mutators is not None and self.applied_mutators:
             raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \
                 do not use mutators when you use LayerChoice/InputChoice')
@@ -129,10 +139,10 @@ class RetiariiExperiment(Experiment):
             self.applied_mutators = mutators

         _logger.info('Starting strategy...')
-        Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start()
+        Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)).start()
         _logger.info('Strategy started!')
-    def start(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> None:
+    def start(self, port: int = 8080, debug: bool = False) -> None:
         """
         Start the experiment in background.
         This method will raise exception on failure.
@@ -144,54 +154,37 @@ class RetiariiExperiment(Experiment):
         debug
             Whether to start in debug mode.
         """
-        # FIXME:
-        if debug:
-            logging.getLogger('nni').setLevel(logging.DEBUG)
-
-        self._proc, self._pipe = launcher.start_experiment(config, port, debug)
-        assert self._proc is not None
-        assert self._pipe is not None
-
-        self.port = port  # port will be None if start up failed
-
-        # dispatcher must be created after pipe initialized
-        # the logic to launch dispatcher in background should be refactored into dispatcher api
-        Thread(target=self._dispatcher.run).start()
-
+        super().start(port, debug)
         self._start_strategy()
-        # TODO: register experiment management metadata
-
-    def stop(self) -> None:
-        """
-        Stop background experiment.
-        """
-        self._proc.kill()
-        self._pipe.close()
-        self.port = None
-        self._proc = None
-        self._pipe = None
-
-    def run(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> str:
+
+    def _create_dispatcher(self):
+        return self._dispatcher
+
+    def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
         """
         Run the experiment.
         This function will block until experiment finish or error.
         """
-        self.config = config
-        self.start(config, port, debug)
-        try:
-            while True:
-                time.sleep(10)
-                status = self.get_status()
-                # TODO: double check the status
-                if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
-                    return status
-        finally:
-            self.stop()
-
-    def get_status(self) -> str:
-        if self.port is None:
-            raise RuntimeError('Experiment is not running')
-        resp = rest.get(self.port, '/check-status')
-        return resp['status']
+        if isinstance(self.trainer, OneShotTrainers):
+            self.trainer.fit()
+        else:
+            assert config is not None, 'You are using classic search mode, config cannot be None!'
+            self.config = config
+            super().run(port, debug)
+
+    def export_top_models(self, top_n: int = 1):
+        """
+        export several top performing models
+        """
+        if top_n != 1:
+            _logger.warning('Only support top_n is 1 for now.')
+        if isinstance(self.trainer, BaseOneShotTrainer):
+            return self.trainer.export()
+        else:
+            _logger.info('For this experiment, you can find out the best one from WebUI.')
+
+    def retrain_model(self, model):
+        """
+        this function retrains the exported model, and test it to output test accuracy
+        """
+        raise NotImplementedError
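The new entry points change how an experiment is launched: one-shot trainers now run in-process via `trainer.fit()`, while classic multi-trial search goes through the base `Experiment.run` with a config. A hedged usage sketch based on the signatures in this diff; the module path `nni.retiarii.experiment.pytorch` and the model, trainer and strategy objects are assumptions for illustration:

```python
# Illustrative launch sketch; base_model/trainer/strategy construction is omitted
# and depends entirely on the user's own search space and training code.
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig

base_model = trainer = strategy = ...          # placeholders, built by the user elsewhere
exp = RetiariiExperiment(base_model, trainer, applied_mutators=None, strategy=strategy)

# classic (multi-trial) search mode: a RetiariiExeConfig is mandatory
config = RetiariiExeConfig(training_service_platform='local')
exp.run(config, port=8080)

# one-shot trainers instead train in-process; no config is needed and the
# chosen architecture can be exported afterwards:
# exp.run()
# best = exp.export_top_models(top_n=1)
```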
...@@ -594,10 +594,10 @@ class Edge: ...@@ -594,10 +594,10 @@ class Edge:
Example forward code snippet: Example forward code snippet:
``` ```
a, b, c = split(x) a, b, c = split(x)
p = concat(a, c) p = concat(a, c)
q = sum(b, p) q = sum(b, p)
z = relu(q) z = relu(q)
``` ```
Edges in above snippet: Edges in above snippet:
......
import logging import logging
import os
from typing import Any, Callable from typing import Any, Callable
import json_tricks import json_tricks
import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType from nni.utils import MetricType
from .graph import MetricData from .graph import MetricData
from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine
from .execution.api import set_execution_engine
from .integration_api import register_advisor
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -55,6 +59,15 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -55,6 +59,15 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.parameters_count = 0 self.parameters_count = 0
engine = self._create_execution_engine()
set_execution_engine(engine)
def _create_execution_engine(self):
if os.environ.get('CGO') == 'true':
return CGOExecutionEngine()
else:
return BaseExecutionEngine()
def handle_initialize(self, data): def handle_initialize(self, data):
"""callback for initializing the advisor """callback for initializing the advisor
Parameters Parameters
...@@ -126,34 +139,3 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -126,34 +139,3 @@ class RetiariiAdvisor(MsgDispatcherBase):
else: else:
return value return value
return value return value
_advisor: RetiariiAdvisor = None
def get_advisor() -> RetiariiAdvisor:
global _advisor
assert _advisor is not None
return _advisor
def register_advisor(advisor: RetiariiAdvisor):
global _advisor
assert _advisor is None
_advisor = advisor
def send_trial(parameters: dict) -> int:
"""
Send a new trial. Executed on tuner end.
Return an ID that is the unique identifier for this trial.
"""
return get_advisor().send_trial(parameters)
def receive_trial_parameters() -> dict:
"""
Receive a new trial's parameters. Executed on trial end.
"""
params = nni.get_next_parameter()
return params
from typing import NewType, Any
import nni
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
_advisor: 'RetiariiAdvisor' = None
def get_advisor() -> 'RetiariiAdvisor':
global _advisor
assert _advisor is not None
return _advisor
def register_advisor(advisor: 'RetiariiAdvisor'):
global _advisor
assert _advisor is None
_advisor = advisor
def send_trial(parameters: dict) -> int:
"""
Send a new trial. Executed on tuner end.
Return an ID that is the unique identifier for this trial.
"""
return get_advisor().send_trial(parameters)
def receive_trial_parameters() -> dict:
"""
Receive a new trial's parameters. Executed on trial end.
"""
params = nni.get_next_parameter()
return params
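Both copies of this helper module implement the same process-wide registry: the advisor is registered exactly once, and later lookups assert that registration already happened. A standalone illustration of the intended call order, with a dummy advisor standing in for `RetiariiAdvisor`:

```
# DummyAdvisor is only a stand-in to show the registration/lookup order.
class DummyAdvisor:
    def send_trial(self, parameters):
        print('submitting trial:', parameters)
        return 0    # the real advisor returns a unique trial id

_advisor = None

def register_advisor(advisor):
    global _advisor
    assert _advisor is None, 'advisor can only be registered once per process'
    _advisor = advisor

def get_advisor():
    assert _advisor is not None, 'register_advisor() must be called first'
    return _advisor

register_advisor(DummyAdvisor())                    # done once, when the dispatcher starts
trial_id = get_advisor().send_trial({'lr': 0.1})    # done later by the strategy/tuner side
```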
...@@ -28,8 +28,10 @@ class Mutator: ...@@ -28,8 +28,10 @@ class Mutator:
""" """
Mutates graphs in model to generate new model. Mutates graphs in model to generate new model.
`Mutator` class will be used in two places: `Mutator` class will be used in two places:
1. Inherit `Mutator` to implement graph mutation logic.
2. Use `Mutator` subclass to implement NAS strategy. 1. Inherit `Mutator` to implement graph mutation logic.
2. Use `Mutator` subclass to implement NAS strategy.
In scenario 1, the subclass should implement `Mutator.mutate()` interface with `Mutator.choice()`. In scenario 1, the subclass should implement `Mutator.mutate()` interface with `Mutator.choice()`.
In scenario 2, strategy should use constructor or `Mutator.bind_sampler()` to initialize subclass, In scenario 2, strategy should use constructor or `Mutator.bind_sampler()` to initialize subclass,
and then use `Mutator.apply()` to mutate model. and then use `Mutator.apply()` to mutate model.
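A minimal self-contained sketch of the two scenarios in this docstring, with heavily simplified stand-ins for the real `Mutator` and `Sampler` classes (the real `apply()` works on graph models and returns a mutated copy; here a string is "mutated" just to show the control flow):

```
import random

class Sampler:
    # the real Sampler.choice also receives mutator, model and index
    def choice(self, candidates):
        raise NotImplementedError

class RandomSampler(Sampler):
    def choice(self, candidates):
        return random.choice(candidates)

class Mutator:
    def __init__(self):
        self.sampler = None

    def bind_sampler(self, sampler):     # scenario 2: strategy attaches a sampler
        self.sampler = sampler
        return self

    def choice(self, candidates):        # used inside mutate()
        return self.sampler.choice(candidates)

    def apply(self, model):              # scenario 2: strategy mutates a model
        return self.mutate(model)

    def mutate(self, model):             # scenario 1: subclass overrides this
        raise NotImplementedError

class PickSuffix(Mutator):               # scenario 1: concrete mutation logic
    def mutate(self, model):
        return model + self.choice(['_a', '_b', '_c'])

print(PickSuffix().bind_sampler(RandomSampler()).apply('model'))
```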
...@@ -104,6 +106,7 @@ class _RecorderSampler(Sampler): ...@@ -104,6 +106,7 @@ class _RecorderSampler(Sampler):
self.recorded_candidates.append(candidates) self.recorded_candidates.append(candidates)
return candidates[0] return candidates[0]
# the following is for inline mutation # the following is for inline mutation
...@@ -122,14 +125,16 @@ class LayerChoiceMutator(Mutator): ...@@ -122,14 +125,16 @@ class LayerChoiceMutator(Mutator):
class InputChoiceMutator(Mutator): class InputChoiceMutator(Mutator):
def __init__(self, node_name: str, n_chosen: int): def __init__(self, node_name: str, n_candidates: int, n_chosen: int, reduction: str):
super().__init__() super().__init__()
self.node_name = node_name self.node_name = node_name
self.n_candidates = n_candidates
self.n_chosen = n_chosen self.n_chosen = n_chosen
self.reduction = reduction
def mutate(self, model): def mutate(self, model):
target = model.get_node_by_name(self.node_name) target = model.get_node_by_name(self.node_name)
candidates = [i for i in range(self.n_chosen)] candidates = [i for i in range(self.n_candidates)]
chosen = self.choice(candidates) chosen = [self.choice(candidates) for _ in range(self.n_chosen)]
target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs', target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs',
{'chosen': chosen}) {'chosen': chosen, 'reduction': self.reduction})
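The updated mutator now samples `n_chosen` indices out of `n_candidates` (instead of a single index out of `n_chosen`) and forwards the reduction to `ChosenInputs`. A standalone restatement of just the sampling step:

```
import random

def sample_input_choice(n_candidates, n_chosen, reduction, choice=random.choice):
    """Draw n_chosen candidate indices (with replacement) and package the decision."""
    candidates = list(range(n_candidates))
    chosen = [choice(candidates) for _ in range(n_chosen)]
    return {'chosen': chosen, 'reduction': reduction}

print(sample_input_choice(4, 2, 'sum'))   # e.g. {'chosen': [3, 1], 'reduction': 'sum'}
```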
import inspect
import logging import logging
from typing import Any, List from typing import Any, List
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...utils import add_record from ...utils import add_record, blackbox_module, uid, version_larger_equal
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
# NOTE: support pytorch version >= 1.5.0
__all__ = [ __all__ = [
'LayerChoice', 'InputChoice', 'Placeholder', 'LayerChoice', 'InputChoice', 'Placeholder',
'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict', 'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
...@@ -29,18 +30,24 @@ __all__ = [ ...@@ -29,18 +30,24 @@ __all__ = [
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold', 'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder', 'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d', 'Flatten', 'Hardsigmoid'
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
'Flatten', 'Hardsigmoid', 'Hardswish'
] ]
if version_larger_equal(torch.__version__, '1.6.0'):
__all__.append('Hardswish')
if version_larger_equal(torch.__version__, '1.7.0'):
__all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])
class LayerChoice(nn.Module): class LayerChoice(nn.Module):
def __init__(self, op_candidates, reduction=None, return_mask=False, key=None): def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
super(LayerChoice, self).__init__() super(LayerChoice, self).__init__()
self.candidate_ops = op_candidates self.op_candidates = op_candidates
self.label = key self.label = key if key is not None else f'layerchoice_{uid()}'
self.key = self.label # deprecated, for backward compatibility
for i, module in enumerate(op_candidates): # deprecated, for backward compatibility
self.add_module(str(i), module)
if reduction or return_mask: if reduction or return_mask:
_logger.warning('input arguments `reduction` and `return_mask` are deprecated!') _logger.warning('input arguments `reduction` and `return_mask` are deprecated!')
...@@ -52,10 +59,12 @@ class InputChoice(nn.Module): ...@@ -52,10 +59,12 @@ class InputChoice(nn.Module):
def __init__(self, n_candidates=None, choose_from=None, n_chosen=1, def __init__(self, n_candidates=None, choose_from=None, n_chosen=1,
reduction="sum", return_mask=False, key=None): reduction="sum", return_mask=False, key=None):
super(InputChoice, self).__init__() super(InputChoice, self).__init__()
self.n_candidates = n_candidates
self.n_chosen = n_chosen self.n_chosen = n_chosen
self.reduction = reduction self.reduction = reduction
self.label = key self.label = key if key is not None else f'inputchoice_{uid()}'
if n_candidates or choose_from or return_mask: self.key = self.label # deprecated, for backward compatibility
if choose_from or return_mask:
_logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!') _logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')
def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor: def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor:
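A hedged usage sketch of these two primitives in a user-defined module. The import path and the way the choices are invoked inside `forward()` follow common Retiarii usage but are assumptions here, not taken from this diff:

```
import torch.nn as nn
import nni.retiarii.nn.pytorch as retiarii_nn   # import path is an assumption

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        # one of the candidate operators is selected per sampled architecture
        self.op = retiarii_nn.LayerChoice([
            nn.Conv2d(16, 16, 3, padding=1),
            nn.MaxPool2d(3, stride=1, padding=1),
        ])
        # choose 1 of 2 candidate inputs; 'sum' is a no-op when only one is chosen
        self.skip = retiarii_nn.InputChoice(n_candidates=2, n_chosen=1, reduction='sum')

    def forward(self, x):
        out = self.op(x)              # assumed to behave like the chosen candidate
        return self.skip([out, x])
```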
...@@ -86,20 +95,37 @@ class Placeholder(nn.Module): ...@@ -86,20 +95,37 @@ class Placeholder(nn.Module):
class ChosenInputs(nn.Module): class ChosenInputs(nn.Module):
def __init__(self, chosen: int): """
"""
def __init__(self, chosen: List[int], reduction: str):
super().__init__() super().__init__()
self.chosen = chosen self.chosen = chosen
self.reduction = reduction
def forward(self, candidate_inputs): def forward(self, candidate_inputs):
# TODO: support multiple chosen inputs return self._tensor_reduction(self.reduction, [candidate_inputs[i] for i in self.chosen])
return candidate_inputs[self.chosen]
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == "sum":
return sum(tensor_list)
if reduction_type == "mean":
return sum(tensor_list) / len(tensor_list)
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
# the following are pytorch modules # the following are pytorch modules
class Module(nn.Module): Module = nn.Module
def __init__(self):
super(Module, self).__init__()
class Sequential(nn.Sequential): class Sequential(nn.Sequential):
...@@ -114,139 +140,116 @@ class ModuleList(nn.ModuleList): ...@@ -114,139 +140,116 @@ class ModuleList(nn.ModuleList):
super(ModuleList, self).__init__(*args) super(ModuleList, self).__init__(*args)
def wrap_module(original_class): Identity = blackbox_module(nn.Identity)
orig_init = original_class.__init__ Linear = blackbox_module(nn.Linear)
argname_list = list(inspect.signature(original_class).parameters.keys()) Conv1d = blackbox_module(nn.Conv1d)
# Make copy of original __init__, so we can call it without recursion Conv2d = blackbox_module(nn.Conv2d)
Conv3d = blackbox_module(nn.Conv3d)
def __init__(self, *args, **kws): ConvTranspose1d = blackbox_module(nn.ConvTranspose1d)
full_args = {} ConvTranspose2d = blackbox_module(nn.ConvTranspose2d)
full_args.update(kws) ConvTranspose3d = blackbox_module(nn.ConvTranspose3d)
for i, arg in enumerate(args): Threshold = blackbox_module(nn.Threshold)
full_args[argname_list[i]] = arg ReLU = blackbox_module(nn.ReLU)
add_record(id(self), full_args) Hardtanh = blackbox_module(nn.Hardtanh)
ReLU6 = blackbox_module(nn.ReLU6)
orig_init(self, *args, **kws) # Call the original __init__ Sigmoid = blackbox_module(nn.Sigmoid)
Tanh = blackbox_module(nn.Tanh)
original_class.__init__ = __init__ # Set the class' __init__ to the new one Softmax = blackbox_module(nn.Softmax)
return original_class Softmax2d = blackbox_module(nn.Softmax2d)
LogSoftmax = blackbox_module(nn.LogSoftmax)
ELU = blackbox_module(nn.ELU)
# TODO: support different versions of pytorch SELU = blackbox_module(nn.SELU)
Identity = wrap_module(nn.Identity) CELU = blackbox_module(nn.CELU)
Linear = wrap_module(nn.Linear) GLU = blackbox_module(nn.GLU)
Conv1d = wrap_module(nn.Conv1d) GELU = blackbox_module(nn.GELU)
Conv2d = wrap_module(nn.Conv2d) Hardshrink = blackbox_module(nn.Hardshrink)
Conv3d = wrap_module(nn.Conv3d) LeakyReLU = blackbox_module(nn.LeakyReLU)
ConvTranspose1d = wrap_module(nn.ConvTranspose1d) LogSigmoid = blackbox_module(nn.LogSigmoid)
ConvTranspose2d = wrap_module(nn.ConvTranspose2d) Softplus = blackbox_module(nn.Softplus)
ConvTranspose3d = wrap_module(nn.ConvTranspose3d) Softshrink = blackbox_module(nn.Softshrink)
Threshold = wrap_module(nn.Threshold) MultiheadAttention = blackbox_module(nn.MultiheadAttention)
ReLU = wrap_module(nn.ReLU) PReLU = blackbox_module(nn.PReLU)
Hardtanh = wrap_module(nn.Hardtanh) Softsign = blackbox_module(nn.Softsign)
ReLU6 = wrap_module(nn.ReLU6) Softmin = blackbox_module(nn.Softmin)
Sigmoid = wrap_module(nn.Sigmoid) Tanhshrink = blackbox_module(nn.Tanhshrink)
Tanh = wrap_module(nn.Tanh) RReLU = blackbox_module(nn.RReLU)
Softmax = wrap_module(nn.Softmax) AvgPool1d = blackbox_module(nn.AvgPool1d)
Softmax2d = wrap_module(nn.Softmax2d) AvgPool2d = blackbox_module(nn.AvgPool2d)
LogSoftmax = wrap_module(nn.LogSoftmax) AvgPool3d = blackbox_module(nn.AvgPool3d)
ELU = wrap_module(nn.ELU) MaxPool1d = blackbox_module(nn.MaxPool1d)
SELU = wrap_module(nn.SELU) MaxPool2d = blackbox_module(nn.MaxPool2d)
CELU = wrap_module(nn.CELU) MaxPool3d = blackbox_module(nn.MaxPool3d)
GLU = wrap_module(nn.GLU) MaxUnpool1d = blackbox_module(nn.MaxUnpool1d)
GELU = wrap_module(nn.GELU) MaxUnpool2d = blackbox_module(nn.MaxUnpool2d)
Hardshrink = wrap_module(nn.Hardshrink) MaxUnpool3d = blackbox_module(nn.MaxUnpool3d)
LeakyReLU = wrap_module(nn.LeakyReLU) FractionalMaxPool2d = blackbox_module(nn.FractionalMaxPool2d)
LogSigmoid = wrap_module(nn.LogSigmoid) FractionalMaxPool3d = blackbox_module(nn.FractionalMaxPool3d)
Softplus = wrap_module(nn.Softplus) LPPool1d = blackbox_module(nn.LPPool1d)
Softshrink = wrap_module(nn.Softshrink) LPPool2d = blackbox_module(nn.LPPool2d)
MultiheadAttention = wrap_module(nn.MultiheadAttention) LocalResponseNorm = blackbox_module(nn.LocalResponseNorm)
PReLU = wrap_module(nn.PReLU) BatchNorm1d = blackbox_module(nn.BatchNorm1d)
Softsign = wrap_module(nn.Softsign) BatchNorm2d = blackbox_module(nn.BatchNorm2d)
Softmin = wrap_module(nn.Softmin) BatchNorm3d = blackbox_module(nn.BatchNorm3d)
Tanhshrink = wrap_module(nn.Tanhshrink) InstanceNorm1d = blackbox_module(nn.InstanceNorm1d)
RReLU = wrap_module(nn.RReLU) InstanceNorm2d = blackbox_module(nn.InstanceNorm2d)
AvgPool1d = wrap_module(nn.AvgPool1d) InstanceNorm3d = blackbox_module(nn.InstanceNorm3d)
AvgPool2d = wrap_module(nn.AvgPool2d) LayerNorm = blackbox_module(nn.LayerNorm)
AvgPool3d = wrap_module(nn.AvgPool3d) GroupNorm = blackbox_module(nn.GroupNorm)
MaxPool1d = wrap_module(nn.MaxPool1d) SyncBatchNorm = blackbox_module(nn.SyncBatchNorm)
MaxPool2d = wrap_module(nn.MaxPool2d) Dropout = blackbox_module(nn.Dropout)
MaxPool3d = wrap_module(nn.MaxPool3d) Dropout2d = blackbox_module(nn.Dropout2d)
MaxUnpool1d = wrap_module(nn.MaxUnpool1d) Dropout3d = blackbox_module(nn.Dropout3d)
MaxUnpool2d = wrap_module(nn.MaxUnpool2d) AlphaDropout = blackbox_module(nn.AlphaDropout)
MaxUnpool3d = wrap_module(nn.MaxUnpool3d) FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout)
FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d) ReflectionPad1d = blackbox_module(nn.ReflectionPad1d)
FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d) ReflectionPad2d = blackbox_module(nn.ReflectionPad2d)
LPPool1d = wrap_module(nn.LPPool1d) ReplicationPad2d = blackbox_module(nn.ReplicationPad2d)
LPPool2d = wrap_module(nn.LPPool2d) ReplicationPad1d = blackbox_module(nn.ReplicationPad1d)
LocalResponseNorm = wrap_module(nn.LocalResponseNorm) ReplicationPad3d = blackbox_module(nn.ReplicationPad3d)
BatchNorm1d = wrap_module(nn.BatchNorm1d) CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d)
BatchNorm2d = wrap_module(nn.BatchNorm2d) Embedding = blackbox_module(nn.Embedding)
BatchNorm3d = wrap_module(nn.BatchNorm3d) EmbeddingBag = blackbox_module(nn.EmbeddingBag)
InstanceNorm1d = wrap_module(nn.InstanceNorm1d) RNNBase = blackbox_module(nn.RNNBase)
InstanceNorm2d = wrap_module(nn.InstanceNorm2d) RNN = blackbox_module(nn.RNN)
InstanceNorm3d = wrap_module(nn.InstanceNorm3d) LSTM = blackbox_module(nn.LSTM)
LayerNorm = wrap_module(nn.LayerNorm) GRU = blackbox_module(nn.GRU)
GroupNorm = wrap_module(nn.GroupNorm) RNNCellBase = blackbox_module(nn.RNNCellBase)
SyncBatchNorm = wrap_module(nn.SyncBatchNorm) RNNCell = blackbox_module(nn.RNNCell)
Dropout = wrap_module(nn.Dropout) LSTMCell = blackbox_module(nn.LSTMCell)
Dropout2d = wrap_module(nn.Dropout2d) GRUCell = blackbox_module(nn.GRUCell)
Dropout3d = wrap_module(nn.Dropout3d) PixelShuffle = blackbox_module(nn.PixelShuffle)
AlphaDropout = wrap_module(nn.AlphaDropout) Upsample = blackbox_module(nn.Upsample)
FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout) UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d)
ReflectionPad1d = wrap_module(nn.ReflectionPad1d) UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d)
ReflectionPad2d = wrap_module(nn.ReflectionPad2d) PairwiseDistance = blackbox_module(nn.PairwiseDistance)
ReplicationPad2d = wrap_module(nn.ReplicationPad2d) AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d)
ReplicationPad1d = wrap_module(nn.ReplicationPad1d) AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d)
ReplicationPad3d = wrap_module(nn.ReplicationPad3d) AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d)
CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d) AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d)
Embedding = wrap_module(nn.Embedding) AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d)
EmbeddingBag = wrap_module(nn.EmbeddingBag) AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d)
RNNBase = wrap_module(nn.RNNBase) TripletMarginLoss = blackbox_module(nn.TripletMarginLoss)
RNN = wrap_module(nn.RNN) ZeroPad2d = blackbox_module(nn.ZeroPad2d)
LSTM = wrap_module(nn.LSTM) ConstantPad1d = blackbox_module(nn.ConstantPad1d)
GRU = wrap_module(nn.GRU) ConstantPad2d = blackbox_module(nn.ConstantPad2d)
RNNCellBase = wrap_module(nn.RNNCellBase) ConstantPad3d = blackbox_module(nn.ConstantPad3d)
RNNCell = wrap_module(nn.RNNCell) Bilinear = blackbox_module(nn.Bilinear)
LSTMCell = wrap_module(nn.LSTMCell) CosineSimilarity = blackbox_module(nn.CosineSimilarity)
GRUCell = wrap_module(nn.GRUCell) Unfold = blackbox_module(nn.Unfold)
PixelShuffle = wrap_module(nn.PixelShuffle) Fold = blackbox_module(nn.Fold)
Upsample = wrap_module(nn.Upsample) AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d) TransformerEncoder = blackbox_module(nn.TransformerEncoder)
UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d) TransformerDecoder = blackbox_module(nn.TransformerDecoder)
PairwiseDistance = wrap_module(nn.PairwiseDistance) TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d) TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d) Transformer = blackbox_module(nn.Transformer)
AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d) Flatten = blackbox_module(nn.Flatten)
AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d) Hardsigmoid = blackbox_module(nn.Hardsigmoid)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d) if version_larger_equal(torch.__version__, '1.6.0'):
TripletMarginLoss = wrap_module(nn.TripletMarginLoss) Hardswish = blackbox_module(nn.Hardswish)
ZeroPad2d = wrap_module(nn.ZeroPad2d)
ConstantPad1d = wrap_module(nn.ConstantPad1d) if version_larger_equal(torch.__version__, '1.7.0'):
ConstantPad2d = wrap_module(nn.ConstantPad2d) SiLU = blackbox_module(nn.SiLU)
ConstantPad3d = wrap_module(nn.ConstantPad3d) Unflatten = blackbox_module(nn.Unflatten)
Bilinear = wrap_module(nn.Bilinear) TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss)
CosineSimilarity = wrap_module(nn.CosineSimilarity)
Unfold = wrap_module(nn.Unfold)
Fold = wrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = wrap_module(nn.TransformerEncoder)
TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer)
#LazyLinear = wrap_module(nn.LazyLinear)
#LazyConv1d = wrap_module(nn.LazyConv1d)
#LazyConv2d = wrap_module(nn.LazyConv2d)
#LazyConv3d = wrap_module(nn.LazyConv3d)
#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
Flatten = wrap_module(nn.Flatten)
#Unflatten = wrap_module(nn.Unflatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
Hardswish = wrap_module(nn.Hardswish)
#SiLU = wrap_module(nn.SiLU)
#TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
#ChannelShuffle = wrap_module(nn.ChannelShuffle)
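Both the removed `wrap_module` and its `blackbox_module` replacement serve the same purpose: remember every wrapped module's constructor arguments so the graph IR can re-emit the instantiation later (the internals of `blackbox_module` are not shown in this diff, so treating them as equivalent is an assumption). A standalone version of the recording trick, with `add_record` replaced by a local dict:

```
import inspect
import torch.nn as nn

_records = {}

def add_record(key, args):
    _records[key] = args

def wrap_module(original_class):
    orig_init = original_class.__init__
    argname_list = list(inspect.signature(original_class).parameters.keys())

    def __init__(self, *args, **kwargs):
        full_args = dict(kwargs)
        for i, arg in enumerate(args):      # map positional args to their names
            full_args[argname_list[i]] = arg
        add_record(id(self), full_args)     # keep init arguments for code generation
        orig_init(self, *args, **kwargs)    # then run the original constructor

    original_class.__init__ = __init__
    return original_class

Linear = wrap_module(nn.Linear)
layer = Linear(4, 8, bias=False)
print(_records[id(layer)])   # -> {'bias': False, 'in_features': 4, 'out_features': 8}
```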
...@@ -121,6 +121,8 @@ class PyTorchOperation(Operation): ...@@ -121,6 +121,8 @@ class PyTorchOperation(Operation):
return f'{output} = {value}' return f'{output} = {value}'
elif self.type == 'prim::ListConstruct': elif self.type == 'prim::ListConstruct':
return f'{output} = [{", ".join(inputs)}]' return f'{output} = [{", ".join(inputs)}]'
elif self.type == 'prim::GetAttr':
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
elif self.type == 'aten::mean': elif self.type == 'aten::mean':
return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})' return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
elif self.type == 'aten::__getitem__': elif self.type == 'aten::__getitem__':
...@@ -133,8 +135,7 @@ class PyTorchOperation(Operation): ...@@ -133,8 +135,7 @@ class PyTorchOperation(Operation):
assert len(inputs) == 2 assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})' return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
elif self.type == 'aten::add': elif self.type == 'aten::add':
assert len(inputs) == 2 return f'{output} = ' + ' + '.join(inputs)
return f'{output} = {inputs[0]} + {inputs[1]}'
elif self.type == OpTypeName.MergedSlice: elif self.type == OpTypeName.MergedSlice:
assert (len(inputs) - 1) % 4 == 0 assert (len(inputs) - 1) % 4 == 0
slices = [] slices = []
...@@ -151,6 +152,8 @@ class PyTorchOperation(Operation): ...@@ -151,6 +152,8 @@ class PyTorchOperation(Operation):
return f'{output} = {inputs[0]}.view({inputs[1]})' return f'{output} = {inputs[0]}.view({inputs[1]})'
elif self.type == 'aten::slice': elif self.type == 'aten::slice':
raise RuntimeError('not supposed to have aten::slice operation') raise RuntimeError('not supposed to have aten::slice operation')
elif self.type == 'aten::Bool':
return f'{output} = bool({inputs[0]})'
else: else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}') raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
......
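The branches above are plain string templates: each operation type maps an output name and a list of input names to one line of the generated forward function. A few of them restated as a standalone function (note that `prim::GetAttr` in the real code reads the attribute owner and name from the operation's parameters rather than from `inputs`; this sketch simplifies that):

```
def op_to_code(op_type, output, inputs):
    if op_type == 'prim::ListConstruct':
        return f'{output} = [{", ".join(inputs)}]'
    if op_type == 'prim::GetAttr':
        # simplified: the real operation stores 'input' and 'name' in its parameters
        return f'{output} = {inputs[0]}.{inputs[1]}'
    if op_type == 'aten::add':
        return f'{output} = ' + ' + '.join(inputs)
    if op_type == 'aten::Bool':
        return f'{output} = bool({inputs[0]})'
    raise RuntimeError(f'unsupported operation type: {op_type}')

print(op_to_code('aten::add', 'out', ['x1', 'x2', 'x3']))    # out = x1 + x2 + x3
print(op_to_code('prim::ListConstruct', 'lst', ['a', 'b']))  # lst = [a, b]
```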
from .tpe_strategy import TPEStrategy from .tpe_strategy import TPEStrategy
from .random_strategy import RandomStrategy
import logging
import random
import time
from .. import Sampler, submit_models, query_available_resources
from .strategy import BaseStrategy
_logger = logging.getLogger(__name__)
class RandomSampler(Sampler):
def choice(self, candidates, mutator, model, index):
return random.choice(candidates)
class RandomStrategy(BaseStrategy):
def __init__(self):
self.random_sampler = RandomSampler()
def run(self, base_model, applied_mutators):
_logger.info('strategy starts...')
while True:
avail_resource = query_available_resources()
if avail_resource > 0:
model = base_model
_logger.info('apply mutators...')
_logger.info('mutators: %s', str(applied_mutators))
for mutator in applied_mutators:
mutator.bind_sampler(self.random_sampler)
model = mutator.apply(model)
# run models
submit_models(model)
else:
time.sleep(2)
import logging import logging
import time
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
from .. import Sampler, submit_models, wait_models from .. import Sampler, submit_models, query_available_resources, is_stopped_exec
from .strategy import BaseStrategy from .strategy import BaseStrategy
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -39,6 +40,7 @@ class TPEStrategy(BaseStrategy): ...@@ -39,6 +40,7 @@ class TPEStrategy(BaseStrategy):
def __init__(self): def __init__(self):
self.tpe_sampler = TPESampler() self.tpe_sampler = TPESampler()
self.model_id = 0 self.model_id = 0
self.running_models = {}
def run(self, base_model, applied_mutators): def run(self, base_model, applied_mutators):
sample_space = [] sample_space = []
...@@ -48,9 +50,10 @@ class TPEStrategy(BaseStrategy): ...@@ -48,9 +50,10 @@ class TPEStrategy(BaseStrategy):
sample_space.extend(recorded_candidates) sample_space.extend(recorded_candidates)
self.tpe_sampler.update_sample_space(sample_space) self.tpe_sampler.update_sample_space(sample_space)
try: _logger.info('strategy starts...')
_logger.info('strategy starts...') while True:
while True: avail_resource = query_available_resources()
if avail_resource > 0:
model = base_model model = base_model
_logger.info('apply mutators...') _logger.info('apply mutators...')
_logger.info('mutators: %s', str(applied_mutators)) _logger.info('mutators: %s', str(applied_mutators))
...@@ -61,9 +64,18 @@ class TPEStrategy(BaseStrategy): ...@@ -61,9 +64,18 @@ class TPEStrategy(BaseStrategy):
model = mutator.apply(model) model = mutator.apply(model)
# run models # run models
submit_models(model) submit_models(model)
wait_models(model) self.running_models[self.model_id] = model
self.tpe_sampler.receive_result(self.model_id, model.metric)
self.model_id += 1 self.model_id += 1
_logger.info('Strategy says: %s', model.metric) else:
except Exception: time.sleep(2)
_logger.error(logging.exception('message'))
_logger.warning('num of running models: %d', len(self.running_models))
to_be_deleted = []
for _id, _model in self.running_models.items():
if is_stopped_exec(_model):
if _model.metric is not None:
self.tpe_sampler.receive_result(_id, _model.metric)
_logger.warning('tpe receive results: %d, %s', _id, _model.metric)
to_be_deleted.append(_id)
for _id in to_be_deleted:
del self.running_models[_id]
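The new TPE loop no longer blocks on `wait_models`; instead it keeps submitted models in `running_models` and periodically drains the finished ones back into the sampler. The bookkeeping, restated standalone with stubs in place of `is_stopped_exec` and the sampler:

```
def collect_finished(running_models, sampler, is_stopped_exec):
    """Feed metrics of finished models to the sampler and forget those models."""
    finished = []
    for model_id, model in running_models.items():
        if is_stopped_exec(model):
            if model.metric is not None:      # failed trials report no metric
                sampler.receive_result(model_id, model.metric)
            finished.append(model_id)
    for model_id in finished:                 # delete after the loop to avoid
        del running_models[model_id]          # mutating the dict while iterating
```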