Unverified Commit d50b4665 authored by Jiahang Xu's avatar Jiahang Xu Committed by GitHub
Browse files

Add python name as Node attribute of graph_gen (#4243)

parent 068775f3
......@@ -14,7 +14,8 @@ from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import (
_convert_name, build_full_name, _without_shape_info,
_extract_info_from_trace_node, get_full_name_by_scope_name,
is_layerchoice_node, match_node, build_cand_name
is_layerchoice_node, match_node, build_cand_name,
build_python_name
)
......@@ -139,7 +140,7 @@ class GraphConverter:
hidden_node.remove()
def handle_graph_nodes(self, script_module, sm_graph,
module, module_name,
module, module_name, module_python_name,
ir_model, ir_graph,
shared_module_index=None):
"""
......@@ -317,10 +318,12 @@ class GraphConverter:
submodule_name, script_module._modules.keys())
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_python_name = build_python_name(module_python_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = self._convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, ir_model)
submodule_full_name, submodule_python_name,
ir_model)
else:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
......@@ -347,12 +350,14 @@ class GraphConverter:
assert predecessor.hasAttribute('name')
module_name_space.append(predecessor.s('name'))
submodule_full_name = build_full_name(module_name, list(reversed(module_name_space)))
submodule_python_name = build_python_name(module_python_name, list(reversed(module_name_space)))
submodule_obj = module
script_submodule = script_module
for each_name in list(reversed(module_name_space)):
submodule_obj = getattr(submodule_obj, each_name)
script_submodule = script_submodule._modules[each_name]
subgraph, sub_m_attrs = self._convert_module(script_submodule, submodule_obj, submodule_full_name, ir_model)
subgraph, sub_m_attrs = self._convert_module(script_submodule, submodule_obj, submodule_full_name,
submodule_python_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
......@@ -362,13 +367,16 @@ class GraphConverter:
# example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
self.global_seq += 1
shared_node_name = build_full_name(submodule_full_name, '', self.global_seq)
shared_node_python_name = build_python_name(submodule_python_name, self.global_seq)
shared_type_operation = Operation.new('shared', {'reference': submodule_full_name})
subcell = ir_graph.add_node(shared_node_name, shared_type_operation)
subcell.python_name = shared_node_python_name
else:
# this module is processed for the first time, build cell for it
if subgraph is None:
# if we do not parse this module's graph, we create Node for this module
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
subcell.python_name = submodule_python_name
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
elif isinstance(submodule_obj, InputChoice):
......@@ -377,6 +385,7 @@ class GraphConverter:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
subcell.python_name = submodule_python_name
shared_module_index[submodule_full_name] = subcell
node_index[node] = subcell
# connect the cell into graph
......@@ -391,7 +400,7 @@ class GraphConverter:
# step #1: generate graph ir for this method
method_ir_graph = Graph(model=ir_model, graph_id=-100, name='temp_graph', _internal=True)
self.handle_graph_nodes(script_module, script_method.graph, module,
module_name, ir_model, method_ir_graph, shared_module_index)
module_name, module_python_name, ir_model, method_ir_graph, shared_module_index)
self.refine_graph(method_ir_graph)
# step #2: merge this graph to its module graph
......@@ -439,6 +448,8 @@ class GraphConverter:
self.global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
'{}.{}'.format(func_type_str, func_name))
func_python_name = build_python_name(module_python_name, func_name)
func_node.python_name = func_python_name
node_index[node] = func_node
self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant':
......@@ -480,7 +491,10 @@ class GraphConverter:
# handle aten::XXX
self.global_seq += 1
aten_op_name = node.kind().replace('::', '__')
aten_op_python_name = node.kind().replace('aten::', '')
aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
aten_python_name = build_python_name(module_python_name, aten_op_python_name)
aten_node.python_name = aten_python_name
node_index[node] = aten_node
self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
else:
......@@ -587,25 +601,29 @@ class GraphConverter:
'accessor': module._accessor
}
def _convert_module(self, script_module, module, module_name, ir_model):
def _convert_module(self, script_module, module, module_name, module_python_name, ir_model):
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
original_type_name = script_module.original_name
m_attrs = None
if original_type_name == OpTypeName.LayerChoice:
graph = Graph(ir_model, -100, module_name, _internal=True) # graph_id is not used now
graph.python_name = module_python_name
candidate_name_list = []
for cand_name in module.names:
cand = module[cand_name]
script_cand = script_module._modules[cand_name]
cand_name = build_cand_name(cand_name, module.label)
candidate_name_list.append(cand_name)
subgraph, attrs = self._convert_module(script_cand, cand, cand_name, ir_model)
cand_full_name = build_cand_name(cand_name, module.label)
cand_python_name = build_python_name(module_python_name, cand_name)
candidate_name_list.append(cand_full_name)
subgraph, attrs = self._convert_module(script_cand, cand, cand_full_name, cand_python_name, ir_model)
if subgraph is not None:
graph.add_node(subgraph.name, Cell(cell_name=subgraph.name, parameters=attrs))
cand_node = graph.add_node(subgraph.name, Cell(cell_name=subgraph.name, parameters=attrs))
cand_node.python_name = cand_python_name
else:
cand_type = '__torch__.' + get_importable_name(cand.__class__)
graph.add_node(cand_name, cand_type, attrs)
cand_node = graph.add_node(cand_full_name, cand_type, attrs)
cand_node.python_name = cand_python_name
graph._register()
return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
elif original_type_name == OpTypeName.InputChoice:
......@@ -629,10 +647,11 @@ class GraphConverter:
sm_graph = script_module.graph
self.global_graph_id += 1
ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)
ir_graph.python_name = module_python_name
# handle graph nodes
self.handle_graph_nodes(script_module, sm_graph, module,
module_name, ir_model, ir_graph)
module_name, module_python_name, ir_model, ir_graph)
self.refine_graph(ir_graph)
ir_graph._register()
......@@ -671,8 +690,7 @@ class GraphConverter:
dict
the input arguments of this module
"""
return self._convert_module(script_module, module, module_name, ir_model)
return self._convert_module(script_module, module, module_name, None, ir_model)
class GraphConverterWithShape(GraphConverter):
......@@ -691,7 +709,7 @@ class GraphConverterWithShape(GraphConverter):
def convert_module(self, script_module, module, module_name, ir_model, dummy_input):
module.eval()
ir_graph, attrs = self._convert_module(script_module, module, module_name, ir_model)
ir_graph, attrs = self._convert_module(script_module, module, module_name, None, ir_model)
self.remove_dummy_nodes(ir_model)
self._initialize_parameters(ir_model)
self._trace_module(module, module_name, ir_model, dummy_input)
......
......@@ -14,6 +14,15 @@ def build_full_name(prefix, name, seq=None):
return '{}__{}{}'.format(prefix, name, str(seq))
def build_python_name(prefix, name):
if isinstance(name, list):
name = '.'.join(name)
if prefix:
return '{}.{}'.format(prefix, name)
else: # predix could be None
return name
def build_cand_name(name, label):
return f'layerchoice_{label}_{name}'
......
......@@ -212,6 +212,20 @@ class Model:
else:
return None
def get_node_by_python_name(self, python_name: str) -> 'Node':
"""
Traverse all the nodes to find the matched node with the given python_name.
"""
matched_nodes = []
for graph in self.graphs.values():
nodes = graph.get_nodes_by_python_name(python_name)
matched_nodes.extend(nodes)
# assert len(matched_nodes) <= 1
if matched_nodes:
return matched_nodes[0]
else:
return None
def get_cell_nodes(self) -> List['Node']:
matched_nodes = []
for graph in self.graphs.values():
......@@ -274,6 +288,8 @@ class Graph:
All input/output/hidden nodes.
edges
...
python_name
The name of torch.nn.Module, should have one-to-one mapping with items in python model.
"""
def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
......@@ -283,6 +299,9 @@ class Graph:
self.id: int = graph_id
self.name: str = name or f'_generated_{graph_id}'
# `python_name` is `None` by default. It should be set after initialization if it is needed.
self.python_name: Optional[str] = None
self.input_node: Node = Node(self, _InputPseudoUid, '_inputs', _IOPseudoOperation('_inputs'), _internal=True)
self.output_node: Node = Node(self, _OutputPseudoUid, '_outputs', _IOPseudoOperation('_outputs'), _internal=True)
self.hidden_nodes: List[Node] = []
......@@ -355,6 +374,13 @@ class Graph:
found = [node for node in self.nodes if node.name == name]
return found[0] if found else None
def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
"""
Returns the node which has specified python_name; or returns `None` if no node has this python_name.
"""
found = [node for node in self.nodes if node.python_name == python_name]
return found[0] if found else None
def get_nodes_by_type(self, operation_type: str) -> List['Node']:
"""
Returns nodes whose operation is specified typed.
......@@ -374,6 +400,9 @@ class Graph:
def get_nodes_by_name(self, name: str) -> List['Node']:
return [node for node in self.hidden_nodes if node.name == name]
def get_nodes_by_python_name(self, python_name: str) -> Optional['Node']:
return [node for node in self.nodes if node.python_name == python_name]
def topo_sort(self) -> List['Node']:
node_to_fanin = {}
curr_nodes = []
......@@ -423,9 +452,11 @@ class Graph:
new_graph.output_node.operation.io_names = self.output_node.operation.io_names
new_graph.input_node.update_label(self.input_node.label)
new_graph.output_node.update_label(self.output_node.label)
new_graph.python_name = self.python_name
for node in self.hidden_nodes:
new_node = Node(new_graph, node.id, node.name, node.operation, _internal=True)
new_node.python_name = node.python_name
new_node.update_label(node.label)
new_node._register()
......@@ -446,11 +477,13 @@ class Graph:
new_graph.output_node.operation.io_names = self.output_node.operation.io_names
new_graph.input_node.update_label(self.input_node.label)
new_graph.output_node.update_label(self.output_node.label)
new_graph.python_name = self.python_name
id_to_new_node = {} # old node ID -> new node object
for old_node in self.hidden_nodes:
new_node = Node(new_graph, uid(), None, old_node.operation, _internal=True)._register()
new_node.python_name = old_node.python_name
new_node.update_label(old_node.label)
id_to_new_node[old_node.id] = new_node
......@@ -514,6 +547,8 @@ class Node:
If two models have nodes with same ID, they are semantically the same node.
name
Mnemonic name. It should have an one-to-one mapping with ID.
python_name
The name of torch.nn.Module, should have one-to-one mapping with items in python model.
label
Optional. If two nodes have the same label, they are considered same by the mutator.
operation
......@@ -535,13 +570,15 @@ class Node:
self.graph: Graph = graph
self.id: int = node_id
self.name: str = name or f'_generated_{node_id}'
# `python_name` is `None` by default. It should be set after initialization if it is needed.
self.python_name: Optional[str] = None
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# maybe we should copy it here or make Operation class immutable, in next release
self.operation: Operation = operation
self.label: Optional[str] = None
def __repr__(self):
return f'Node(id={self.id}, name={self.name}, label={self.label}, operation={self.operation})'
return f'Node(id={self.id}, name={self.name}, python_name={self.python_name}, label={self.label}, operation={self.operation})'
@property
def predecessors(self) -> List['Node']:
......@@ -626,6 +663,8 @@ class Node:
ret['operation']['cell_name'] = self.operation.cell_name
if self.label is not None:
ret['label'] = self.label
if self.python_name is not None:
ret['python_name'] = self.python_name
return ret
......
......@@ -1232,5 +1232,33 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
x = torch.randn(5, 3, 2)
self.run_test(SizeModel(10, 5), (x, ))
def test_python_name(self):
from .inject_nn import inject_pytorch_nn, remove_inject_pytorch_nn
try:
inject_pytorch_nn()
torchvision_model_zoo = {
'resnet18': torchvision.models.resnet18(),
'alexnet': torchvision.models.alexnet(),
'vgg16': torchvision.models.vgg16(),
'squeezenet': torchvision.models.squeezenet1_0(),
'shufflenet_v2': torchvision.models.shufflenet_v2_x1_0(),
'mobilenet_v2': torchvision.models.mobilenet_v2(),
'resnext50_32x4d': torchvision.models.resnext50_32x4d(),
'wide_resnet50_2': torchvision.models.wide_resnet50_2(),
'mnasnet': torchvision.models.mnasnet1_0(),
}
dummy_input=torch.randn(1, 3, 224, 224)
for model in torchvision_model_zoo.values():
model_ir = self._convert_model(model, dummy_input)
current_name = [node.python_name for node in model_ir.get_nodes() if node.python_name]
mentioned = set()
for k in model.state_dict():
k = ".".join(k.split(".")[:-1])
if k not in mentioned:
assert k in current_name, f'{k} not in state_name'
mentioned.add(k)
finally:
remove_inject_pytorch_nn()
class TestPytorchWithShape(TestPytorch, ConvertWithShapeMixin):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment