"src/value/git@developer.sourcefind.cn:gaoqiong/yaml-cpp.git" did not exist on "e32b3cd93f61fd2e81bfa6dd48475a563ded2722"
Unverified commit d50b4665, authored by Jiahang Xu, committed by GitHub

Add python name as Node attribute of graph_gen (#4243)

parent 068775f3
@@ -14,7 +14,8 @@ from .op_types import MODULE_EXCEPT_LIST, OpTypeName
 from .utils import (
     _convert_name, build_full_name, _without_shape_info,
     _extract_info_from_trace_node, get_full_name_by_scope_name,
-    is_layerchoice_node, match_node, build_cand_name
+    is_layerchoice_node, match_node, build_cand_name,
+    build_python_name
 )
@@ -139,7 +140,7 @@
         hidden_node.remove()

     def handle_graph_nodes(self, script_module, sm_graph,
-                           module, module_name,
+                           module, module_name, module_python_name,
                            ir_model, ir_graph,
                            shared_module_index=None):
         """
@@ -317,10 +318,12 @@
                     submodule_name, script_module._modules.keys())
                 submodule_full_name = build_full_name(module_name, submodule_name)
+                submodule_python_name = build_python_name(module_python_name, submodule_name)
                 submodule_obj = getattr(module, submodule_name)
                 subgraph, sub_m_attrs = self._convert_module(script_module._modules[submodule_name],
                                                              submodule_obj,
-                                                             submodule_full_name, ir_model)
+                                                             submodule_full_name, submodule_python_name,
+                                                             ir_model)
             else:
                 # %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
                 # %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
@@ -347,12 +350,14 @@
                     assert predecessor.hasAttribute('name')
                     module_name_space.append(predecessor.s('name'))
                 submodule_full_name = build_full_name(module_name, list(reversed(module_name_space)))
+                submodule_python_name = build_python_name(module_python_name, list(reversed(module_name_space)))
                 submodule_obj = module
                 script_submodule = script_module
                 for each_name in list(reversed(module_name_space)):
                     submodule_obj = getattr(submodule_obj, each_name)
                     script_submodule = script_submodule._modules[each_name]
-                subgraph, sub_m_attrs = self._convert_module(script_submodule, submodule_obj, submodule_full_name, ir_model)
+                subgraph, sub_m_attrs = self._convert_module(script_submodule, submodule_obj, submodule_full_name,
+                                                             submodule_python_name, ir_model)
             else:
                 raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
@@ -362,13 +367,16 @@
                 # example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
                 self.global_seq += 1
                 shared_node_name = build_full_name(submodule_full_name, '', self.global_seq)
+                shared_node_python_name = build_python_name(submodule_python_name, self.global_seq)
                 shared_type_operation = Operation.new('shared', {'reference': submodule_full_name})
                 subcell = ir_graph.add_node(shared_node_name, shared_type_operation)
+                subcell.python_name = shared_node_python_name
             else:
                 # this module is processed for the first time, build cell for it
                 if subgraph is None:
                     # if we do not parse this module's graph, we create Node for this module
                     subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
+                    subcell.python_name = submodule_python_name
                     if isinstance(submodule_obj, Placeholder):
                         subcell.update_label(submodule_obj.label)
                     elif isinstance(submodule_obj, InputChoice):
@@ -377,6 +385,7 @@
                     # Graph already created, create Cell for it
                     new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
                     subcell = ir_graph.add_node(submodule_full_name, new_cell)
+                    subcell.python_name = submodule_python_name
                     shared_module_index[submodule_full_name] = subcell
                 node_index[node] = subcell
                 # connect the cell into graph
@@ -391,7 +400,7 @@
             # step #1: generate graph ir for this method
             method_ir_graph = Graph(model=ir_model, graph_id=-100, name='temp_graph', _internal=True)
             self.handle_graph_nodes(script_module, script_method.graph, module,
-                                    module_name, ir_model, method_ir_graph, shared_module_index)
+                                    module_name, module_python_name, ir_model, method_ir_graph, shared_module_index)
             self.refine_graph(method_ir_graph)

             # step #2: merge this graph to its module graph
@@ -439,6 +448,8 @@
                 self.global_seq += 1
                 func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
                                               '{}.{}'.format(func_type_str, func_name))
+                func_python_name = build_python_name(module_python_name, func_name)
+                func_node.python_name = func_python_name
                 node_index[node] = func_node
                 self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
             elif node.kind() == 'prim::Constant':
@@ -480,7 +491,10 @@
                 # handle aten::XXX
                 self.global_seq += 1
                 aten_op_name = node.kind().replace('::', '__')
+                aten_op_python_name = node.kind().replace('aten::', '')
                 aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
+                aten_python_name = build_python_name(module_python_name, aten_op_python_name)
+                aten_node.python_name = aten_python_name
                 node_index[node] = aten_node
                 self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
             else:
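To make the aten name handling above concrete, here is a minimal sketch of the two names derived from a TorchScript op kind; the kind 'aten::relu' and the parent python name 'features.3' are made-up examples, not values taken from this commit:

    # Hypothetical values, mirroring the two replace() calls in the hunk above.
    kind = 'aten::relu'
    aten_op_name = kind.replace('::', '__')           # 'aten__relu', used for the graph-level full name
    aten_op_python_name = kind.replace('aten::', '')  # 'relu', appended to the module's python name
    # With module_python_name == 'features.3', build_python_name would produce 'features.3.relu'.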
@@ -587,25 +601,29 @@
             'accessor': module._accessor
         }

-    def _convert_module(self, script_module, module, module_name, ir_model):
+    def _convert_module(self, script_module, module, module_name, module_python_name, ir_model):
         # NOTE: have not supported nested LayerChoice, i.e., a candidate module
         # also has LayerChoice or InputChoice or ValueChoice
         original_type_name = script_module.original_name
         m_attrs = None
         if original_type_name == OpTypeName.LayerChoice:
             graph = Graph(ir_model, -100, module_name, _internal=True)  # graph_id is not used now
+            graph.python_name = module_python_name
             candidate_name_list = []
             for cand_name in module.names:
                 cand = module[cand_name]
                 script_cand = script_module._modules[cand_name]
-                cand_name = build_cand_name(cand_name, module.label)
-                candidate_name_list.append(cand_name)
-                subgraph, attrs = self._convert_module(script_cand, cand, cand_name, ir_model)
+                cand_full_name = build_cand_name(cand_name, module.label)
+                cand_python_name = build_python_name(module_python_name, cand_name)
+                candidate_name_list.append(cand_full_name)
+                subgraph, attrs = self._convert_module(script_cand, cand, cand_full_name, cand_python_name, ir_model)
                 if subgraph is not None:
-                    graph.add_node(subgraph.name, Cell(cell_name=subgraph.name, parameters=attrs))
+                    cand_node = graph.add_node(subgraph.name, Cell(cell_name=subgraph.name, parameters=attrs))
+                    cand_node.python_name = cand_python_name
                 else:
                     cand_type = '__torch__.' + get_importable_name(cand.__class__)
-                    graph.add_node(cand_name, cand_type, attrs)
+                    cand_node = graph.add_node(cand_full_name, cand_type, attrs)
+                    cand_node.python_name = cand_python_name
             graph._register()
             return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
         elif original_type_name == OpTypeName.InputChoice:
@@ -629,10 +647,11 @@
         sm_graph = script_module.graph
         self.global_graph_id += 1
         ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)
+        ir_graph.python_name = module_python_name

         # handle graph nodes
         self.handle_graph_nodes(script_module, sm_graph, module,
-                                module_name, ir_model, ir_graph)
+                                module_name, module_python_name, ir_model, ir_graph)
         self.refine_graph(ir_graph)

         ir_graph._register()
@@ -671,8 +690,7 @@
         dict
             the input arguments of this module
         """
-
-        return self._convert_module(script_module, module, module_name, ir_model)
+        return self._convert_module(script_module, module, module_name, None, ir_model)


 class GraphConverterWithShape(GraphConverter):
@@ -691,7 +709,7 @@
     def convert_module(self, script_module, module, module_name, ir_model, dummy_input):
         module.eval()

-        ir_graph, attrs = self._convert_module(script_module, module, module_name, ir_model)
+        ir_graph, attrs = self._convert_module(script_module, module, module_name, None, ir_model)
         self.remove_dummy_nodes(ir_model)
         self._initialize_parameters(ir_model)
         self._trace_module(module, module_name, ir_model, dummy_input)
...
@@ -14,6 +14,15 @@ def build_full_name(prefix, name, seq=None):
         return '{}__{}{}'.format(prefix, name, str(seq))


+def build_python_name(prefix, name):
+    if isinstance(name, list):
+        name = '.'.join(name)
+    if prefix:
+        return '{}.{}'.format(prefix, name)
+    else:  # prefix could be None
+        return name
+
+
 def build_cand_name(name, label):
     return f'layerchoice_{label}_{name}'
...
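A minimal sketch of how the new helper composes dotted attribute paths. The function body is copied from the hunk above; the prefixes and names ('features', '3', '0', 'conv1') are made-up examples:

    def build_python_name(prefix, name):
        if isinstance(name, list):
            name = '.'.join(name)
        if prefix:
            return '{}.{}'.format(prefix, name)
        else:  # prefix could be None
            return name

    assert build_python_name(None, 'features') == 'features'
    assert build_python_name('features', '3') == 'features.3'
    assert build_python_name('features.3', ['0', 'conv1']) == 'features.3.0.conv1'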
@@ -212,6 +212,20 @@ class Model:
         else:
             return None

+    def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
+        """
+        Traverse all the nodes to find the matched node with the given python_name.
+        """
+        matched_nodes = []
+        for graph in self.graphs.values():
+            nodes = graph.get_nodes_by_python_name(python_name)
+            matched_nodes.extend(nodes)
+        # assert len(matched_nodes) <= 1
+        if matched_nodes:
+            return matched_nodes[0]
+        else:
+            return None
+
     def get_cell_nodes(self) -> List['Node']:
         matched_nodes = []
         for graph in self.graphs.values():
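A hedged usage sketch of the new model-level lookup; it assumes `model_ir` is a Model produced by the converter, and the dotted name 'layer1.0.conv1' is a made-up python_name for a ResNet-style network:

    node = model_ir.get_node_by_python_name('layer1.0.conv1')
    if node is not None:
        print(node.name, node.python_name, node.operation)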
@@ -274,6 +288,8 @@ class Graph:
         All input/output/hidden nodes.
     edges
         ...
+    python_name
+        The name of the corresponding torch.nn.Module; it has a one-to-one mapping with the modules in the Python model.
     """

     def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
@@ -283,6 +299,9 @@
         self.id: int = graph_id
         self.name: str = name or f'_generated_{graph_id}'
+        # `python_name` is `None` by default. It should be set after initialization if it is needed.
+        self.python_name: Optional[str] = None

         self.input_node: Node = Node(self, _InputPseudoUid, '_inputs', _IOPseudoOperation('_inputs'), _internal=True)
         self.output_node: Node = Node(self, _OutputPseudoUid, '_outputs', _IOPseudoOperation('_outputs'), _internal=True)
         self.hidden_nodes: List[Node] = []
@@ -355,6 +374,13 @@
         found = [node for node in self.nodes if node.name == name]
         return found[0] if found else None

+    def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
+        """
+        Returns the node which has the specified python_name, or `None` if no node has this python_name.
+        """
+        found = [node for node in self.nodes if node.python_name == python_name]
+        return found[0] if found else None
+
     def get_nodes_by_type(self, operation_type: str) -> List['Node']:
         """
         Returns nodes whose operation is specified typed.
@@ -374,6 +400,9 @@
     def get_nodes_by_name(self, name: str) -> List['Node']:
         return [node for node in self.hidden_nodes if node.name == name]

+    def get_nodes_by_python_name(self, python_name: str) -> List['Node']:
+        return [node for node in self.nodes if node.python_name == python_name]
+
     def topo_sort(self) -> List['Node']:
         node_to_fanin = {}
         curr_nodes = []
@@ -423,9 +452,11 @@
         new_graph.output_node.operation.io_names = self.output_node.operation.io_names
         new_graph.input_node.update_label(self.input_node.label)
         new_graph.output_node.update_label(self.output_node.label)
+        new_graph.python_name = self.python_name

         for node in self.hidden_nodes:
             new_node = Node(new_graph, node.id, node.name, node.operation, _internal=True)
+            new_node.python_name = node.python_name
             new_node.update_label(node.label)
             new_node._register()
@@ -446,11 +477,13 @@
         new_graph.output_node.operation.io_names = self.output_node.operation.io_names
         new_graph.input_node.update_label(self.input_node.label)
         new_graph.output_node.update_label(self.output_node.label)
+        new_graph.python_name = self.python_name

         id_to_new_node = {}  # old node ID -> new node object
         for old_node in self.hidden_nodes:
             new_node = Node(new_graph, uid(), None, old_node.operation, _internal=True)._register()
+            new_node.python_name = old_node.python_name
             new_node.update_label(old_node.label)
             id_to_new_node[old_node.id] = new_node
@@ -514,6 +547,8 @@
         If two models have nodes with same ID, they are semantically the same node.
     name
         Mnemonic name. It should have an one-to-one mapping with ID.
+    python_name
+        The name of the corresponding torch.nn.Module; it has a one-to-one mapping with the modules in the Python model.
     label
         Optional. If two nodes have the same label, they are considered same by the mutator.
     operation
@@ -535,13 +570,15 @@
         self.graph: Graph = graph
         self.id: int = node_id
         self.name: str = name or f'_generated_{node_id}'
+        # `python_name` is `None` by default. It should be set after initialization if it is needed.
+        self.python_name: Optional[str] = None
         # TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
         # maybe we should copy it here or make Operation class immutable, in next release
         self.operation: Operation = operation
         self.label: Optional[str] = None

     def __repr__(self):
-        return f'Node(id={self.id}, name={self.name}, label={self.label}, operation={self.operation})'
+        return f'Node(id={self.id}, name={self.name}, python_name={self.python_name}, label={self.label}, operation={self.operation})'

     @property
     def predecessors(self) -> List['Node']:
@@ -626,6 +663,8 @@
             ret['operation']['cell_name'] = self.operation.cell_name
         if self.label is not None:
             ret['label'] = self.label
+        if self.python_name is not None:
+            ret['python_name'] = self.python_name
         return ret
...
@@ -1232,5 +1232,33 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
         x = torch.randn(5, 3, 2)
         self.run_test(SizeModel(10, 5), (x, ))

+    def test_python_name(self):
+        from .inject_nn import inject_pytorch_nn, remove_inject_pytorch_nn
+        try:
+            inject_pytorch_nn()
+
+            torchvision_model_zoo = {
+                'resnet18': torchvision.models.resnet18(),
+                'alexnet': torchvision.models.alexnet(),
+                'vgg16': torchvision.models.vgg16(),
+                'squeezenet': torchvision.models.squeezenet1_0(),
+                'shufflenet_v2': torchvision.models.shufflenet_v2_x1_0(),
+                'mobilenet_v2': torchvision.models.mobilenet_v2(),
+                'resnext50_32x4d': torchvision.models.resnext50_32x4d(),
+                'wide_resnet50_2': torchvision.models.wide_resnet50_2(),
+                'mnasnet': torchvision.models.mnasnet1_0(),
+            }
+            dummy_input = torch.randn(1, 3, 224, 224)
+            for model in torchvision_model_zoo.values():
+                model_ir = self._convert_model(model, dummy_input)
+                current_name = [node.python_name for node in model_ir.get_nodes() if node.python_name]
+                mentioned = set()
+                for k in model.state_dict():
+                    k = ".".join(k.split(".")[:-1])
+                    if k not in mentioned:
+                        assert k in current_name, f'{k} not in converted python names'
+                        mentioned.add(k)
+        finally:
+            remove_inject_pytorch_nn()


 class TestPytorchWithShape(TestPytorch, ConvertWithShapeMixin):
     pass
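A small illustration of the check the test performs: every module path that owns parameters in state_dict (the key minus its trailing parameter name) should appear among the python_name values collected from the converted IR. The key below is a made-up example:

    k = 'features.3.squeeze.weight'             # hypothetical state_dict key
    module_path = '.'.join(k.split('.')[:-1])   # -> 'features.3.squeeze'
    # The test asserts that module_path is in current_name, the list of python_name values from the IR.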