Unverified Commit 8af73146 authored by QuanluZhang, committed by GitHub

[Retiarii] pytorch code converter (#3052)

parent 002af91f
......@@ -26,6 +26,7 @@ import subprocess
import re
import json
import requests
import yaml
__all__ = [
'Experiment',
......@@ -265,6 +266,38 @@ class Experiment:
self._endpoint = 'http://localhost:{}'.format(self._port)
self._exp_id = self.get_experiment_profile()['id']
def tmp_start_retiarii(self, graph_ir, training_approach,
applied_mutators, strategy, exp_config):
# prepare search space file which includes base graph IR and mutators
search_space = {}
search_space['base_model_ir'] = graph_ir
search_space['applied_mutators'] = applied_mutators
search_space['training_approach'] = training_approach
with open('search_space.json', 'w') as f:
json.dump(search_space, f)
# add advisor config to exp_config
exp_config['searchSpacePath'] = 'search_space.json'
exp_config['useAnnotation'] = False
exp_config['advisor'] = {
'codeDir': '.',
'classFileName': 'advisor_entry.py',
'className': 'RetiariiAdvisor',
'classArgs': {
'strategy': '{}.{}'.format(strategy['filename'], strategy['funcname'])
}
}
# add trial config to exp_config
exp_config['trial'] = {
'command': 'python3 -m nni.retiarii.trial_entry',
'codeDir': '../..',
'gpuNum': 0
}
# dump exp_config to nni.yml
with open('nni.yml', 'w') as f:
yaml.dump(exp_config, f)
# start experiment
self.start_experiment('nni.yml')
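# Hypothetical invocation of tmp_start_retiarii (argument shapes inferred from its body):
# exp.tmp_start_retiarii(graph_ir, training_approach, applied_mutators,
#                        strategy={'filename': 'my_strategy', 'funcname': 'main'},
#                        exp_config={'experimentName': 'retiarii-demo'})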
def start_experiment(self, config_file, port=None, debug=False):
"""
Start an experiment with specified configuration file and connect to it.
......
from .execution import *
from .graph import *
from .mutator import *
from .operation import *
from .model_apis import nn
......@@ -3,11 +3,27 @@ from typing import *
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
from ..operation import Operation, Cell
# TODO: fix: inputs is a list; decide how to handle a single-element list vs. a single element
def model_to_pytorch_script(model: Model) -> str:
graphs = [graph_to_pytorch_model(name, cell) for name, cell in model.graphs.items()]
return _PyTorchScriptTemplate.format('\n\n'.join(graphs)).strip()
graphs = []
total_pkgs = set()
for name, cell in model.graphs.items():
import_pkgs, graph_code = graph_to_pytorch_model(name, cell)
graphs.append(graph_code)
total_pkgs.update(import_pkgs)
# TODO: set correct PATH for the packages (after launch refactor)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
def _convert_name(name: str) -> str:
"""
Convert names that use '.' as separator into valid variable names in code, e.g. 'stem.conv' -> 'stem__conv'
"""
return name.replace('.', '__')
def _convert_names(names: List[str]) -> List[str]:
return [_convert_name(name) for name in names]
def _sorted_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
......@@ -21,8 +37,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> str:
def _format_inputs(node: Node) -> List[str]:
edges = _sorted_incoming_edges(node)
inputs = []
for edge in edges:
......@@ -41,40 +56,48 @@ def _format_inputs(node: Node) -> str:
else:
# when the input comes from a multi-output operator: needs to know which one it comes from
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
return ', '.join(inputs)
return inputs
def graph_to_pytorch_model(graph_name: str, graph: Graph) -> Tuple[Set[str], str]:
nodes = graph.nodes # FIXME: topological sort is needed here
# handle module node and function node differently
# only need to generate code for module here
import_pkgs = set()
node_codes = []
for node in nodes:
if node.operation:
node_codes.append(node.operation.to_init_code(node.name))
pkg_name = node.operation.get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
node_code = node.operation.to_init_code(node.name)
if node_code is not None:
node_codes.append(node_code)
if graph.input_names is None:
input_code = '*_inputs'
else:
input_code = ', '.join(graph.input_names)
# TODO: remove _convert_names (after merging input_names and input_node)
input_code = ', '.join(_convert_names(graph.input_names))
edge_codes = []
for node in nodes:
sorted_nodes = graph.topo_sort()
for node in sorted_nodes:
if node.operation:
inputs = _format_inputs(node)
edge_codes.append(node.operation.to_forward_code(node.name, node.name, inputs))
output_code = _format_inputs(graph.output_node)
if not output_code:
output_code = 'None'
# TODO: refactor graph output_node
output_names = _format_inputs(graph.output_node)
output_names = _convert_names(output_names)
if not output_names:
output_names = ['None']
linebreak = '\n '
return _PyTorchModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else graph_name),
return import_pkgs, _PyTorchModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else _convert_name(graph_name)),
inputs=input_code,
outputs=output_code,
outputs=', '.join(output_names),
nodes=linebreak.join(node_codes),
edges=linebreak.join(edge_codes)
)
......@@ -88,6 +111,11 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sys
sys.path.append("test/convert_test")
{}
{}
'''
......
# PyTorch Graph Converter
## Namespace for PyTorch Graph
We need a concrete rule for naming the nodes of a graph with a namespace.
Each node has a name, either specified or generated. Nodes in the same hierarchy cannot share a name.
* The names of module nodes follow this rule natively, because we use the variable names of the instantiated modules, just as the PyTorch graph does.
* For the nodes created in `forward` function, we use a global sequence number.
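For example, with `build_full_name` in this commit, a submodule held as `self.conv1` on the top-level module `_model` is named `_model.conv1`, while nodes created in `forward` get sequence-numbered names such as `_model.Constant1` or `_model.ListConstruct2`.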
### Namespace for mutated (new) nodes
TBD
## Graph Simplification
TBD
## Node Types
We define a concrete type string for each node type.
## Module's Input Arguments
We use wrappers to obtain the input arguments of modules. Users need to use our wrapped `nn` and wrapped `Module`, as sketched below.
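A minimal sketch (the user-facing import path is assumed; the wrapped classes come from `model_apis/nn.py` in this commit):

```python
from nni.retiarii.model_apis import nn  # assumed import path

class Net(nn.Module):                    # wrapped Module records (args, kwargs)
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3)  # wrapped Conv2d records its init args
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))
```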
## Control Flow
### for loop
Currently, we only support for loops over a `ModuleList` (or `ModuleDict`), which TorchScript unfolds automatically. That is to say, we do not support TorchScript loops (`prim::Loop`) for now. See the sketch below.
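A sketch of the supported pattern (plain `torch.nn` here, for brevity):

```python
import torch.nn as nn

class Stack(nn.Module):
    def __init__(self):
        super(Stack, self).__init__()
        # iterating a ModuleList is unrolled by TorchScript, so no prim::Loop appears
        self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
```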
### if/else
For now, we only handle the case where the condition is a constant or an attribute. In that case, only one branch is kept when generating the graph, as sketched below.
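A sketch of the supported pattern (the attribute condition is assumed constant at conversion time):

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, use_residual: bool):
        super(Block, self).__init__()
        self.use_residual = use_residual   # read as prim::GetAttr; decides which branch is kept
        self.body = nn.Linear(64, 64)

    def forward(self, x):
        if self.use_residual:
            return x + self.body(x)
        return self.body(x)
```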
from .graph_gen import convert_to_graph
from .visualize import visualize_model
import json_tricks
import re
import torch
from ..graph import Graph, Node, Edge, Model
from ..operation import Cell, Operation
from ..model_apis.nn import Placeholder
from .op_types import RETIARII_BASE_OPS, MODULE_EXCEPT_LIST, Type
from .utils import build_full_name
global_seq = 0
global_graph_id = 0
modules_arg = None
def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, ignore_first=False):
"""
Parameters
----------
ir_graph : Graph
node : torch._C.Node
graph_inputs : List[torch._C.Value]
a list of a script graph's inputs
node_index : Dict
new_node : Node
newly created ir node corresponding to `node`
ignore_first : bool
if it is true, skip the first input
"""
new_node_input_idx = 0
for _input in node.inputs():
if ignore_first:
ignore_first = False
continue
# handle source node
if _input in graph_inputs:
idx = graph_inputs.index(_input)
src_node = ir_graph.input_node
src_node_idx = idx
else:
predecessor_node = _input.node()
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
# find out the index of _input in the outputs of predecessor_node
predecessor_outputs = [_output for _output in predecessor_node.outputs()]
if len(predecessor_outputs) == 1:
idx = None
else:
idx = predecessor_outputs.index(_input)
ir_predecessor_node = node_index[predecessor_node]
src_node_idx = idx
# get source node
# the input is output of a basic node
assert isinstance(ir_predecessor_node, Node)
src_node = ir_predecessor_node
# handle destination node
dst_node = new_node
dst_node_idx = new_node_input_idx
# create edge
ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))
new_node_input_idx += 1
def _handle_inputs(ir_graph, node, graph_inputs, node_index, module_name, ignore_first=False):
"""
Create a prim::GetAttr node when necessary, because in some cases the original
prim::GetAttr nodes are removed, for example, the prim::GetAttr consumed by prim::CallMethod.
"""
global global_seq
for _input in node.inputs():
# for CallMethod and CallFunction
if ignore_first:
ignore_first = False
continue
if _input in graph_inputs:
continue
if _input.node().kind() == 'prim::Constant':
assert _input.node() in node_index
if _input.node().kind() == 'prim::GetAttr':
if _input.node() not in node_index:
node_type, attrs = handle_prim_attr_node(_input.node())
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, Type.Attr, global_seq),
node_type, attrs)
node_index[_input.node()] = new_node
print('==handle inputs getattr==: ', _input.node())
def create_prim_constant_node(ir_graph, node, module_name):
global global_seq
attrs = {}
if node.outputsAt(0).toIValue() is not None:
attrs = {'value': node.outputsAt(0).toIValue()}
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, Type.Constant, global_seq),
node.kind(), attrs)
return new_node
def handle_prim_attr_node(node):
assert node.hasAttribute('name')
assert node.inputsAt(0).debugName() == 'self'
assert node.inputsAt(0).unique() == 0
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
return node.kind(), attrs
def _remove_mangle(module_type_str):
return re.sub(r'\.___torch_mangle_\d+', '', module_type_str)
def remove_unconnected_nodes(ir_graph):
"""
Parameters
----------
ir_graph : Graph
our ir graph representation
"""
# build index of outputs of Node(s)
node_fanout = set()
for edge in ir_graph.edges:
node_fanout.add(edge.head.id)
to_removes = []
for hidden_node in ir_graph.hidden_nodes:
if hidden_node.id not in node_fanout:
assert isinstance(hidden_node, Node)
to_removes.append(hidden_node)
# some constants are unused, e.g., a function name stored as a prim::Constant
assert hidden_node.operation.type == 'prim::Constant', 'the type is {}'.format(hidden_node.operation.type)
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
"""
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
sm_graph : torch._C.Graph
module : nn.Module
module_name : str
ir_model : Model
ir_graph : Graph
"""
# handle inputs
graph_inputs = []
for _input in sm_graph.inputs():
if _input.debugName() == 'self':
assert _input.unique() == 0
continue
graph_inputs.append(_input)
# TODO: add scope name
ir_graph._add_input(_input.debugName())
node_index = {} # graph node to graph ir node
def handle_if_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created ir node
"""
# for now, only handle a prim::If whose input is a constant or an attribute
# TODO: support constant expression
inputs = [i for i in node.inputs()]
assert len(inputs) == 1
if inputs[0].node().kind() not in ['prim::Constant', 'prim::GetAttr']:
raise RuntimeError('"if" whose condition is not constant or attribute has not been supported yet!')
chosen_block = None
if inputs[0].node().kind() == 'prim::Constant':
chosen_block = 0 if inputs[0].toIValue() else 1
if inputs[0].node().kind() == 'prim::GetAttr':
chosen_block = 0 if getattr(module, inputs[0].node().s('name')) else 1
blocks = [block for block in node.blocks()]
assert len(blocks) == 2
last_block_node = None
for block_node in blocks[chosen_block].nodes():
last_block_node = handle_single_node(block_node)
assert last_block_node is not None
return last_block_node
def handle_single_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created ir node
"""
global global_seq
if node.kind() == 'prim::CallMethod':
# get and handle the first input, which should be an nn.Module
assert node.hasAttribute('name')
if node.s('name') == 'forward':
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str = _remove_mangle(node.inputsAt(0).type().str())
submodule = node.inputsAt(0).node()
assert submodule.kind() == 'prim::GetAttr'
assert submodule.hasAttribute('name')
submodule_name = submodule.s('name')
assert submodule_name in script_module._modules
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, ir_model)
# TODO: try not-connected placeholder in TorchScript
# TODO: match subgraph with maintained graphs
# build cell
if subgraph is None:
# if we do not parse this module's graph, we create Node for this module
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
else:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
node_index[node] = subcell
_handle_inputs(ir_graph, node, graph_inputs, node_index, module_name, ignore_first=True)
# connect the cell into graph
_add_edge(ir_graph, node, graph_inputs, node_index, subcell, ignore_first=True)
else:
raise RuntimeError('unsupported CallMethod {}'.format(node.s('name')))
elif node.kind() == 'prim::CallFunction':
func_type_str = _remove_mangle(node.inputsAt(0).type().str())
func = node.inputsAt(0).node()
assert func.kind() == 'prim::Constant'
assert func.hasAttribute('name')
func_name = func.s('name')
# create node for func
global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, global_seq),
'{}.{}'.format(func_type_str, func_name))
node_index[node] = func_node
_handle_inputs(ir_graph, node, graph_inputs, node_index, module_name, ignore_first=True)
_add_edge(ir_graph, node, graph_inputs, node_index, func_node, ignore_first=True)
elif node.kind() == 'prim::Constant':
# TODO: how about calling a function twice? two constant nodes or one?
new_node = create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct':
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, Type.ListConstruct, global_seq), node.kind())
node_index[node] = new_node
_handle_inputs(ir_graph, node, graph_inputs, node_index, module_name)
_add_edge(ir_graph, node, graph_inputs, node_index, new_node)
elif node.kind().startswith('aten::'):
# handle aten::XXX
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, Type.BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_handle_inputs(ir_graph, node, graph_inputs, node_index, module_name)
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node)
elif node.kind() == 'prim::Loop':
raise RuntimeError('Loop has not been supported yet!')
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
node_index[node] = last_block_node
elif node.kind() == 'prim::GetAttr':
pass
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
if node in node_index:
return node_index[node]
else:
return None
for node in sm_graph.nodes():
handle_single_node(node)
return node_index
def convert_module(script_module, module, module_name, ir_model):
global global_graph_id
global modules_arg
assert id(module) in modules_arg, 'id not exist: {}, {}'.format(id(module), module_name)
if isinstance(modules_arg[id(module)], tuple):
positional_args, keyword_args = modules_arg[id(module)]
m_attrs = keyword_args
# TODO: remove positional args
m_attrs['positional_args'] = positional_args
else:
m_attrs = modules_arg[id(module)]
original_type_name = script_module.original_name
if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
return None, m_attrs
if original_type_name in RETIARII_BASE_OPS:
return None, m_attrs
# handle TorchScript graph
sm_graph = script_module.graph
global_graph_id += 1
ir_graph = Graph(model=ir_model, graph_id=global_graph_id, name=module_name, _internal=True)
# handle graph nodes
node_index = handle_graph_nodes(script_module, sm_graph, module,
module_name, ir_model, ir_graph)
# handle graph outputs
graph_outputs = []
for _output in sm_graph.outputs():
graph_outputs.append(_output) # <class 'torch._C.Value'>
ir_graph._add_output(_output.debugName())
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
src_node_idx = None
else:
src_node_idx = predecessor_node_outputs.index(_output)
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
tail=(ir_graph.output_node, None))
remove_unconnected_nodes(ir_graph)
ir_graph._register()
return ir_graph, m_attrs
def convert_to_graph(script_module, module, recorded_modules_arg):
"""
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module obtained with torch.jit.script
module : nn.Module
the targeted module instance
recorded_modules_arg : dict
the recorded constructor arguments of each module in the model

Returns
-------
Model
the constructed IR model
"""
global modules_arg
modules_arg = recorded_modules_arg
model = Model(_internal=True)
module_name = '_model'
graph, m_attrs = convert_module(script_module, module, module_name, model)
return model
MODULE_EXCEPT_LIST = ['Sequential']
RETIARII_BASE_OPS = ['Placeholder']
class Type:
"""Node Type class
"""
Attr = 'Attr'
Constant = 'Constant'
ListConstruct = 'ListConstruct'
# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
'aten::relu': 'Relu',
'aten::add': 'Add'
}
BasicOpsTF = {}
def build_full_name(prefix, name, seq=None):
if seq is None:
return '{}.{}'.format(prefix, name)
else:
return '{}.{}{}'.format(prefix, name, str(seq))
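# e.g. build_full_name('_model', 'conv1')       -> '_model.conv1'
#      build_full_name('_model', 'Constant', 3) -> '_model.Constant3'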
import graphviz
def convert_to_visualize(graph_ir, vgraph):
for name, graph in graph_ir.items():
if name == '_training_config':
continue
with vgraph.subgraph(name='cluster'+name) as subgraph:
subgraph.attr(color='blue')
cell_node = {}
ioput = {'_inputs': '{}-{}'.format(name, '_'.join(graph['inputs'])),
'_outputs': '{}-{}'.format(name, '_'.join(graph['outputs']))}
subgraph.node(ioput['_inputs'])
subgraph.node(ioput['_outputs'])
for node_name, node_value in graph['nodes'].items():
value = node_value['operation']
if value['type'] == '_cell':
cell_input_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['inputs']))
cell_output_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['outputs']))
cell_node[node_name] = (cell_input_name, cell_output_name)
print('cell: ', node_name, cell_input_name, cell_output_name)
else:
subgraph.node(node_name)
for edge in graph['edges']:
src = edge['head'][0]
if src == '_inputs':
src = ioput['_inputs']
elif src in cell_node:
src = cell_node[src][1]
dst = edge['tail'][0]
if dst == '_outputs':
dst = ioput['_outputs']
elif dst in cell_node:
dst = cell_node[dst][0]
subgraph.edge(src, dst)
def visualize_model(graph_ir):
vgraph = graphviz.Digraph('G', filename='vgraph', format='png')
convert_to_visualize(graph_ir, vgraph)
vgraph.render()
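# Usage sketch (assumes the graph IR was dumped to a JSON dict, e.g. via Model._dump):
# import json
# visualize_model(json.load(open('graph_ir.json')))  # renders vgraph and vgraph.png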
import time
import importlib.util
from typing import *
from ..graph import Model, ModelStatus
......@@ -10,7 +11,8 @@ _execution_engine = None
_default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener',
'submit_models', 'wait_models', 'query_available_resources']
'submit_models', 'wait_models', 'query_available_resources',
'get_base_model_ir', 'get_specified_mutators', 'get_trainer']
def get_execution_engine() -> BaseExecutionEngine:
......@@ -30,6 +32,34 @@ def get_and_register_default_listener(engine: AbstractExecutionEngine) -> Defaul
engine.register_graph_listener(_default_listener)
return _default_listener
def _get_search_space() -> 'Dict':
engine = get_execution_engine()
while engine.get_search_space() is None:
time.sleep(1)
return engine.get_search_space()
def get_base_model_ir() -> 'Model':
search_space = _get_search_space()
return Model._load(search_space['base_model_ir'])
def get_specified_mutators() -> List['Mutator']:
search_space = _get_search_space()
applied_mutators = []
for each in search_space['applied_mutators']:
spec = importlib.util.spec_from_file_location("module.name", each['filepath'])
m = importlib.util.module_from_spec(spec)
spec.loader.exec_module(m)
class_constructor = getattr(m, each['classname'])
mutator = class_constructor(**each['args'])
applied_mutators.append(mutator)
return applied_mutators
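# Hypothetical shape of one entry in search_space['applied_mutators'] (keys taken from the loop above):
# {'filepath': 'mutators/block_mutator.py', 'classname': 'BlockMutator', 'args': {'target_label': 'mutable0'}}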
def get_trainer() -> 'BaseTrainer':
search_space = _get_search_space()
return search_space['training_approach']
def submit_models(*models: Model) -> None:
engine = get_execution_engine()
......
import logging
from typing import *
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
......@@ -5,6 +6,7 @@ from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData
from ..integration import send_trial, receive_trial_parameters, get_advisor
_logger = logging.getLogger(__name__)
class BaseGraphData:
def __init__(self, model_script: str, training_module: str, training_kwargs: Dict[str, Any]) -> None:
......@@ -48,6 +50,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
self._running_models: Dict[int, Model] = dict()
def get_search_space(self) -> 'JSON':
advisor = get_advisor()
return advisor.search_space
def submit_models(self, *models: Model) -> None:
for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model),
......@@ -59,11 +65,16 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def _send_trial_callback(self, paramater: dict) -> None:
for listener in self._listeners:
listener.on_resource_used(0) # FIXME: find the real resource id
_logger.warning('resources: {}'.format(listener.resources))
if not listener.has_available_resource():
_logger.warning('There is no available resource, but trial is submitted.')
listener.on_resource_used(1)
_logger.warning('on_resource_used: {}'.format(listener.resources))
def _request_trial_jobs_callback(self, num_trials: int) -> None:
for listener in self._listeners:
listener.on_resource_available([0] * num_trials) # FIXME: find the real resource id
listener.on_resource_available(1 * num_trials)
_logger.warning('on_resource_available: {}'.format(listener.resources))
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]
......
......@@ -6,7 +6,10 @@ from .interface import *
class DefaultListener(AbstractGraphListener):
def __init__(self):
self.resources: List[WorkerInfo] = []
self.resources: int = 0 # simply resource count
def has_available_resource(self) -> bool:
return self.resources > 0
def on_metric(self, model: Model, metric: MetricData) -> None:
model.metric = metric
......@@ -20,8 +23,8 @@ class DefaultListener(AbstractGraphListener):
else:
model.status = ModelStatus.Failed
def on_resource_available(self, resources: List[WorkerInfo]) -> None:
def on_resource_available(self, resources: int) -> None:
self.resources += resources
def on_resource_used(self, resources: List[WorkerInfo]) -> None:
self.resources = [r for r in self.resources if r not in resources]
def on_resource_used(self, resources: int) -> None:
self.resources -= resources
......@@ -5,7 +5,7 @@ Model representation.
import copy
from enum import Enum
import json
from typing import (Any, Dict, List, Optional, Tuple, overload)
from typing import (Any, Dict, List, Optional, Tuple, Union, overload)
from .operation import Cell, Operation, _PseudoOperation
......@@ -146,6 +146,29 @@ class Model:
ret['_training_config'] = self.training_config._dump()
return ret
def apply_trainer(self, module, args) -> None:
# TODO: rethink the way of specifying a trainer
self.training_config = TrainingConfig(module, args)
def get_nodes_by_label(self, label: str) -> List['Node']:
"""
Traverse all the nodes to find those that match the given label.
There can be multiple nodes with the same label; a namespace name can uniquely
identify a graph or node.

NOTE: the implementation does not support class abstraction
"""
matched_nodes = []
for graph in self.graphs.values():
nodes = graph.get_nodes_by_label(label)
matched_nodes.extend(nodes)
return matched_nodes
def get_by_name(self, name: str) -> Union['Graph', 'Node']:
"""
Find the graph or node that has the given namespace name.
"""
class ModelStatus(Enum):
"""
......@@ -210,6 +233,7 @@ class Graph:
self.id: int = graph_id
self.name: str = name or f'_generated_{graph_id}'
# TODO: why not merge the names into input_node and output_node???
self.input_names: Optional[List[str]] = None
self.output_names: Optional[List[str]] = None
......@@ -227,24 +251,54 @@ class Graph:
def nodes(self) -> List['Node']:
return [self.input_node, self.output_node] + self.hidden_nodes
# mutation
def _add_input(self, input_name) -> None:
if self.input_names is None:
self.input_names = [input_name]
else:
self.input_names.append(input_name)
def _add_output(self, output_name) -> None:
if self.output_names is None:
self.output_names = [output_name]
else:
self.output_names.append(output_name)
@overload
def add_node(self, operation: Operation) -> 'Node': ...
def add_node(self, name: str, operation: Operation) -> 'Node': ...
@overload
def add_node(self, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ...
def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ...
def add_node(self, operation_or_type, parameters={}):
def add_node(self, name, operation_or_type, parameters={}):
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
op = Operation.new(operation_or_type, parameters)
return Node(self, self.model._uid(), None, op, _internal=True)._register()
op = Operation.new(operation_or_type, parameters, name)
return Node(self, self.model._uid(), name, op, _internal=True)._register()
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ...
def insert_node_on_edge(self, edge, name, operation_or_type, parameters={}) -> 'Node':
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
op = Operation.new(operation_or_type, parameters, name)
new_node = Node(self, self.model._uid(), name, op, _internal=True)._register()
# update edges
self.add_edge((edge.head, edge.head_slot), (new_node, None))
self.add_edge((new_node, None), (edge.tail, edge.tail_slot))
self.del_edge(edge)
return new_node
# mutation
def add_edge(self, head: Tuple['Node', Optional[int]], tail: Tuple['Node', Optional[int]]) -> 'Edge':
assert head[0].graph is self and tail[0].graph is self
return Edge(head, tail)._register()
return Edge(head, tail, _internal=True)._register()
def del_edge(self, edge: 'Edge') -> None:
self.edges.remove(edge)
def get_node_by_name(self, name: str) -> Optional['Node']:
"""
......@@ -259,8 +313,31 @@ class Graph:
"""
return [node for node in self.hidden_nodes if node.operation.type == operation_type]
def topo_sort(self) -> List['Node']: # TODO
...
def get_nodes_by_label(self, label: str) -> List['Node']:
return [node for node in self.hidden_nodes if node.label == label]
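# topo_sort below implements Kahn's algorithm: repeatedly emit nodes whose remaining fan-in
# drops to zero; the final assert detects cycles (a node with nonzero fan-in left over).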
def topo_sort(self) -> List['Node']:
node_to_fanin = {}
curr_nodes = []
for node in self.nodes:
fanin = len(node.incoming_edges)
node_to_fanin[node] = fanin
if fanin == 0:
curr_nodes.append(node)
sorted_nodes = []
while curr_nodes:
curr_node = curr_nodes.pop(0)
sorted_nodes.append(curr_node)
for successor in curr_node.successors:
node_to_fanin[successor] -= 1
if node_to_fanin[successor] == 0:
curr_nodes.append(successor)
for key in node_to_fanin:
assert node_to_fanin[key] == 0
return sorted_nodes
def fork(self) -> 'Graph':
"""
......@@ -278,7 +355,9 @@ class Graph:
new_graph.output_names = self.output_names
for node in self.hidden_nodes:
Node(new_graph, node.id, node.name, node.operation, _internal=True)._register()
new_node = Node(new_graph, node.id, node.name, node.operation, _internal=True)
new_node.update_label(node.label)
new_node._register()
id_to_new_node = {node.id: node for node in new_graph.nodes}
......@@ -375,11 +454,12 @@ class Node:
def __init__(self, graph, node_id, name, operation, _internal=False):
self.graph: Graph = graph
self.id: int = node_id
self.name: str = name
self.name: str = name or f'_generated_{node_id}'
self.operation: Operation = operation
self.label: str = None
def __repr__(self):
return f'Node(id={self.id}, name={self.name}, operation={self.operation})'
return f'Node(id={self.id}, name={self.name}, label={self.label}, operation={self.operation})'
@property
def predecessors(self) -> List['Node']:
......@@ -402,7 +482,8 @@ class Node:
assert isinstance(self.operation, Cell)
return self.graph.model.graphs[self.operation.parameters['cell']]
# mutation
def update_label(self, label: str) -> None:
self.label = label
@overload
def update_operation(self, operation: Operation) -> None: ...
......@@ -433,22 +514,30 @@ class Node:
def __eq__(self, other: object) -> bool:
return self is other
def __hash__(self) -> int:
return hash(id(self))
def _register(self) -> 'Node':
self.graph.hidden_nodes.append(self)
return self
@staticmethod
def _load(graph: Graph, name: str, ir: Any) -> 'Node':
if ir['type'] == '_cell':
op = Cell(ir['cell'], ir.get('parameters', {}))
if ir['operation']['type'] == '_cell':
op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}))
else:
op = Operation.new(ir['type'], ir.get('parameters', {}))
return Node(graph, graph.model._uid(), name, op)
op = Operation.new(ir['operation']['type'], ir['operation'].get('parameters', {}))
node = Node(graph, graph.model._uid(), name, op)
if 'label' in ir:
node.update_label(ir['label'])
return node
def _dump(self) -> Any:
ret = {'type': self.operation.type, 'parameters': self.operation.parameters}
ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters}}
if isinstance(self.operation, Cell):
ret['cell'] = self.operation.cell_name
ret['operation']['cell_name'] = self.operation.cell_name
if self.label is not None:
ret['label'] = self.label
return ret
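# Example of a dumped node in the new IR format (sketch):
# {'operation': {'type': '_cell', 'parameters': {}, 'cell_name': 'cell0'}, 'label': 'mutable0'}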
......
......@@ -47,6 +47,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
def __init__(self, strategy: Union[str, Callable]):
super(RetiariiAdvisor, self).__init__()
register_advisor(self) # register the current advisor as the "global only" advisor
self.search_space = None
self.send_trial_callback: Callable[[dict], None] = None
self.request_trial_jobs_callback: Callable[[int], None] = None
......@@ -56,10 +57,19 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.strategy = utils.import_(strategy) if isinstance(strategy, str) else strategy
self.parameters_count = 0
_logger.info('Starting strategy...')
threading.Thread(target=self.strategy).start()
_logger.info('Strategy started!')
def handle_initialize(self, data):
pass
"""callback for initializing the advisor
Parameters
----------
data: dict
search space
"""
self.handle_update_search_space(data)
send(CommandType.Initialized, '')
def send_trial(self, parameters):
"""
......@@ -94,7 +104,8 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.request_trial_jobs_callback(num_trials) # pylint: disable=not-callable
def handle_update_search_space(self, data):
pass
_logger.info('Received search space: {}'.format(data))
self.search_space = data
def handle_trial_end(self, data):
_logger.info('Trial end: {}'.format(data)) # do nothing
......
import inspect
import logging
import torch.nn as nn
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
#consoleHandler = logging.StreamHandler()
#consoleHandler.setLevel(logging.INFO)
#_logger.addHandler(consoleHandler)
_records = None
def enable_record_args():
global _records
_records = {}
_logger.info('args recording enabled')
def disable_record_args():
global _records
_records = None
_logger.info('args recording disabled')
def get_records():
global _records
return _records
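# Sketch of the intended recording flow (wiring assumed from this commit):
# enable_record_args()
# net = Net()                            # Net is built from the wrapped Module/layers below
# records = get_records()
# disable_record_args()
# script_module = torch.jit.script(net)
# ir_model = convert_to_graph(script_module, net, records)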
class Placeholder(nn.Module):
def __init__(self, label, related_info):
global _records
if _records is not None:
_records[id(self)] = related_info
self.label = label
self.related_info = related_info
super(Placeholder, self).__init__()
def forward(self, x):
return x
class Module(nn.Module):
def __init__(self, *args, **kwargs):
# TODO: currently users have to forward their __init__ arguments to super().__init__ for recording
global _records
if _records is not None:
# TODO: change tuple to dict
_records[id(self)] = (args, kwargs)
#print('my module: ', id(self), args, kwargs)
super(Module, self).__init__()
class Sequential(nn.Sequential):
def __init__(self, *args):
global _records
if _records is not None:
_records[id(self)] = {} # no args need to be recorded
super(Sequential, self).__init__(*args)
def wrap_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
def __init__(self, *args, **kws):
global _records
if _records is not None:
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = arg
_records[id(self)] = full_args
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
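# Wrapped drop-in replacements for common torch.nn layers; constructing one records
# its init args when recording is enabled.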
Conv2d = wrap_module(nn.Conv2d)
BatchNorm2d = wrap_module(nn.BatchNorm2d)
ReLU = wrap_module(nn.ReLU)
Dropout = wrap_module(nn.Dropout)
Linear = wrap_module(nn.Linear)
from typing import (Any, Dict)
from typing import (Any, Dict, List)
from . import debug_configs
__all__ = ['Operation', 'Cell']
def _convert_name(name: str) -> str:
"""
Convert names that use '.' as separator into valid variable names in code, e.g. 'stem.conv' -> 'stem__conv'
"""
return name.replace('.', '__')
class Operation:
"""
......@@ -30,11 +35,10 @@ class Operation:
self.parameters: Dict[str, Any] = parameters
def to_init_code(self, field: str) -> str:
params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
return f'self.{field} = {self._to_class_name()}({params})'
raise NotImplementedError()
def to_forward_code(self, field: str, output: str, *inputs: str) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
raise NotImplementedError()
def _to_class_name(self) -> str:
raise NotImplementedError()
......@@ -43,9 +47,10 @@ class Operation:
return True
@staticmethod
def new(type_name: str, parameters: Dict[str, Any] = {}) -> 'Operation':
def new(type_name: str, parameters: Dict[str, Any] = {}, cell_name: str = None) -> 'Operation':
if type_name == '_cell':
return Cell(parameters['cell'])
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
return Cell(cell_name, parameters)
else:
if debug_configs.framework.lower() in ('torch', 'pytorch'):
from .operation_def import torch_op_def # pylint: disable=unused-import
......@@ -77,15 +82,60 @@ class Operation:
class PyTorchOperation(Operation):
def _to_class_name(self) -> str:
return 'nn.' + self.type
if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):]
elif self.type.startswith('__mutated__.'):
return self.type[len('__mutated__.'):]
else:
return None
def get_import_pkg(self) -> str:
if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):].split('.')[0]
elif self.type.startswith('__mutated__.'):
return self.type[len('__mutated__.'):].split('.')[0]
else:
return None
def to_init_code(self, field: str) -> str:
field = _convert_name(field)
if self._to_class_name() is not None:
assert 'positional_args' not in self.parameters
kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
return f'self.{field} = {self._to_class_name()}({kw_params})'
return None
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
field = _convert_name(field)
output = _convert_name(output)
inputs = [_convert_name(_input) for _input in inputs]
if self._to_class_name() is not None:
return f'{output} = self.{field}({", ".join(inputs)})'
elif self.type.startswith('Function.'):
func_name = self.type[len('Function.'):]
return f'{output} = F.{func_name}({", ".join(inputs)})'
elif self.type == 'prim::Constant':
if self.parameters:
value = self.parameters['value']
else:
value = None
return f'{output} = {value}'
elif self.type == 'prim::ListConstruct':
return f'{output} = [{", ".join(inputs)}]'
elif self.type == 'aten::mean':
return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
else:
raise RuntimeError('unsupported operation type: {}'.format(self.type))
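# Examples of forward lines generated above (sketch, after name conversion):
# module call:          out = self.stem__conv(_input)
# Function.relu:        out = F.relu(_input)
# prim::Constant:       out = 1
# prim::ListConstruct:  out = [a, b]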
class TensorFlowOperation(Operation):
def _to_class_name(self) -> str:
return 'K.layers.' + self.type
class Cell(Operation):
class Cell(PyTorchOperation):
"""
TODO: this is pytorch cell
An operation reference to a subgraph.
Example code:
......@@ -122,7 +172,8 @@ class Cell(Operation):
self.parameters = parameters
def _to_class_name(self):
return self.cell_name
# TODO: ugly, think about how to refactor this part
return _convert_name(self.cell_name)
class _PseudoOperation(Operation):
......@@ -139,7 +190,7 @@ class _PseudoOperation(Operation):
def to_init_code(self, field: str) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def to_forward_code(self, field: str, output: str, *inputs: str) -> str:
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def __bool__(self) -> bool:
......
......@@ -31,6 +31,13 @@ def get_default_transform(dataset: str) -> Any:
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
if dataset == 'CIFAR10':
return transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# unsupported dataset, return None
return None
......
......@@ -62,7 +62,7 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
if sys.platform == 'win32':
node_command = os.path.join(entry_dir, 'node.exe')
else:
node_command = 'node'
node_command = os.path.join(entry_dir, 'node')
cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform]
if mode == 'view':
cmds += ['--start_mode', 'resume']
......
import os
import sys
from nni.retiarii.integration import RetiariiAdvisor