Unverified Commit 7ee5036b authored by Ningxin Zheng, committed by GitHub

Bugfix issue2485 (#2524)

parent e1e1977c
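This change lets the graph utilities build their graph from an already-traced model instead of always re-tracing. A minimal usage sketch of the new entry point (the toy model and tensor sizes below are illustrative, not part of the patch):

import torch
import torch.nn as nn
from nni._graph_utils import TorchModuleGraph

# any nn.Module works; this toy model only illustrates the call pattern
net = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
dummy_input = torch.rand(1, 10)

# trace once, then hand the trace to TorchModuleGraph so the model is not traced again
traced = torch.jit.trace(net, dummy_input)
module_graph = TorchModuleGraph(traced_model=traced)

# nodes are now indexed by their globally unique names
print(list(module_graph.name_to_node.keys()))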
@@ -8,19 +8,21 @@ import re
from collections import defaultdict
import torch
from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy
CLASSTYPE_KIND = 'ClassType'
GETATTR_KIND = 'prim::GetAttr'
_logger = logging.getLogger(__name__)
def build_module_graph(model, dummy_input):
return TorchModuleGraph(model, dummy_input)
def build_graph(model, dummy_input, verbose=False):
g = TorchProtoGraph(model, dummy_input, verbose)
return g.graph_def, g.stepstats
def parse_traced_name(module_name):
prefix = 'TracedModule['
suffix = ']'
@@ -28,11 +30,13 @@ def parse_traced_name(module_name):
module_name = module_name[len(prefix):-len(suffix)]
return module_name
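# Illustration (not part of the patch): a traced module's _name looks like
# 'TracedModule[Conv2d]', so parse_traced_name strips the wrapper, e.g.
#   parse_traced_name('TracedModule[Conv2d]')  ->  'Conv2d'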
class TorchGraph:
"""
This class extracts the topology graph of a pytorch model by tracing
"""
def __init__(self, model, dummy_input):
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
Parameters
----------
@@ -40,25 +44,39 @@ class TorchGraph:
The model user wants to speed up
dummy_input : pytorch tensor
The dummy input for ```jit.trace```; users should put it on the right device before passing it in
traced_model : torch._C.torch.jit.TopLevelTracedModule
An already traced model. If traced_model is not None, TorchGraph will build the graph
based on this traced model and will not trace the model again.
"""
assert torch.__version__ >= '1.3.1'
# check if the input is legal
if traced_model is not None:
assert isinstance(traced_model, torch.jit.TopLevelTracedModule)
self.trace = traced_model
# it's ok if the graph is already unpacked
torch._C._jit_pass_inline(self.trace.graph)
elif model is not None and dummy_input is not None:
self.bound_model = model
self._trace(model, dummy_input)
else:
raise Exception(
'Please provide model & dummy_input or the traced_model as inputs')
def _trace(self, model, dummy_input):
with torch.onnx.set_training(model, False):
self.trace = torch.jit.trace(model, dummy_input)
torch._C._jit_pass_inline(self.trace.graph)
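# Note (descriptive, not part of the patch): torch.onnx.set_training temporarily
# switches the model to eval mode while tracing and restores the previous mode on
# exit; torch._C._jit_pass_inline inlines the called functions/submodule graphs so
# the whole forward shows up in one flat trace graph.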
class TorchProtoGraph(TorchGraph):
"""
Generates model graph for pytorch models in protobuf, this implementation is borrowed from pytorch v1.4.0,
and fixed following issues:
Generates the model graph for pytorch models in protobuf. This implementation
is borrowed from pytorch v1.4.0 and fixes the following issues:
https://github.com/pytorch/pytorch/issues/33691
https://github.com/pytorch/pytorch/issues/33670
"""
def __init__(self, model, dummy_input, verbose=False):
super().__init__(model, dummy_input)
@@ -70,8 +88,10 @@ class TorchProtoGraph(TorchGraph):
list_of_nodes = self.parse(self.trace.graph, self.trace, dummy_input)
if verbose:
print(self.trace.graph)
self.stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
self.graph_def = GraphDef(node=list_of_nodes, versions=VersionDef(producer=22))
self.stepstats = RunMetadata(step_stats=StepStats(
dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
self.graph_def = GraphDef(
node=list_of_nodes, versions=VersionDef(producer=22))
def parse(self, graph, trace, args=None, omit_useless_nodes=True):
"""This method parses an optimized PyTorch model graph and produces
@@ -94,16 +114,20 @@ class TorchProtoGraph(TorchGraph):
nodes_py.append(NodePyIO(node, 'input'))
attr_to_scope = dict()
node_to_name = lambda d: str(d).split(":")[0].strip()
def node_to_name(d):
return str(d).split(":")[0].strip()
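# Illustration (assumed print format): a prim::GetAttr node prints roughly as
#   %4 : ClassType<Conv2d> = prim::GetAttr[name="conv1"](%self.1)
# so node_to_name() returns the output identifier before the first colon, '%4' here.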
for node in graph.nodes():
if node.kind() == GETATTR_KIND:
attr_name = node.s('name')
node_name = node_to_name(node)
parent = node.input().node()
if parent.kind() == GETATTR_KIND: # If the parent node is not the top-level "self" node
# If the parent node is not the top-level "self" node
if parent.kind() == GETATTR_KIND:
parent_scope = attr_to_scope[node_to_name(parent)]
attr_scope = parent_scope.split('/')[-1]
attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
attr_to_scope[node_name] = '{}/{}.{}'.format(
parent_scope, attr_scope, attr_name)
else:
attr_to_scope[node_name] = '__module.{}'.format(attr_name)
# We don't need classtype nodes; scope will provide this information
@@ -114,7 +138,8 @@ class TorchProtoGraph(TorchGraph):
else:
nodes_py.append(NodePyOP(node))
for i, node in enumerate(graph.outputs()): # Create sink nodes for output ops
# Create sink nodes for output ops
for i, node in enumerate(graph.outputs()):
node_py = NodePyIO(node, 'output')
node_py.debugName = "output.{}".format(i + 1)
node_py.inputs = [node.debugName()]
@@ -136,23 +161,33 @@ class TorchProtoGraph(TorchGraph):
node.scopeName = base_name
else:
module_name += '.' + alias
node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias)
node.scopeName += '/' + \
(alias_to_name[module_name]
if module_name in alias_to_name else alias)
nodes_py.populate_namespace_from_OP_to_IO()
return nodes_py.to_proto()
class NodePyGroup(NodePy):
"""
This class is used to represent a graph node which consists of multiple jit traced nodes. In a pytorch trace graph,
multiple nodes may be traced for one torch.nn.Module object; we group them together to form a single node that
represents the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
"""
def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=None):
def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None):
"""
Parameters:
-----------
name: str
node name, such as `conv1`, `backbone.classifier`
unique_name: str
A globally unique name for the current node. Some modules,
such as relu, may be reused several times, so the scope name
is not suitable as a globally unique identifier; therefore we
add a unique_name for each node as its globally unique identifier.
The unique_name should be used to traverse the module graph.
node_type: str
`module` or `func`
op_type: str
@@ -167,6 +202,7 @@ class NodePyGroup(NodePy):
super(NodePyGroup, self).__init__(name, [])
self.node_cpps = node_cpps
self.name = name
self.unique_name = unique_name
self.op_type = op_type
self.type = node_type
self.nodes = []
@@ -178,7 +214,7 @@ class NodePyGroup(NodePy):
def add_nodes(self, node_cpps):
for node_cpp in node_cpps:
nodepy = NodePyOP(node_cpp)
nodepy.name = str(node_cpp).split(':')[0].strip().replace('%', '')
nodepy.name = node_cpp.scopeName() + '_' + node_cpp.kind()
self.nodes.append(nodepy)
def sub_node_names(self):
@@ -186,7 +222,8 @@ class NodePyGroup(NodePy):
def __repr__(self):
return 'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'.format(
self.name, self.type, self.op_type, self.sub_node_names(), self.inputs, self.outputs, self.auxiliary
self.name, self.type, self.op_type, self.sub_node_names(),
self.inputs, self.outputs, self.auxiliary
)
@@ -194,12 +231,14 @@ class TorchModuleGraph(TorchGraph):
"""
Generates model graph, each node is created from single or multiple jit trace nodes.
"""
def __init__(self, model, dummy_input):
super().__init__(model, dummy_input)
def __init__(self, model=None, dummy_input=None, traced_model=None):
super().__init__(model, dummy_input, traced_model)
self.global_count = 0
self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
module_type):
"""
For trace graph nodes, some nodes are not in modules; these nodes are usually generated by
the functions directly called in module ```forward```. For such nodes, some of them are
@@ -217,6 +256,8 @@ class TorchModuleGraph(TorchGraph):
key: input name, value: a node that uses this input
output_to_node : dict
key: output name, value: a node that generates this output
module_type : str
can be 'module' or 'func'
Returns
-------
@@ -224,11 +265,12 @@ class TorchModuleGraph(TorchGraph):
the expanded non-prim node
"""
# TODO: scope name could be empty
node_name = '.'.join([self._get_module_name(node.scopeName()), node.kind(), str(self.global_count)])
node_name = '.'.join([self._get_module_name(
node.scopeName()), node.kind(), str(self.global_count)])
unique_name = node_name
_logger.debug("expand non-prim node, node name: %s", node_name)
self.global_count += 1
op_type = node.kind()
node_group = [node]
inputs = list()
outputs = list()
@@ -249,28 +291,78 @@ class TorchModuleGraph(TorchGraph):
inputs.append(input_name)
for output in node.outputs():
outputs.append(output.debugName())
nodepy = NodePyGroup(node_name, 'func', op_type, node_group, inputs=inputs, outputs=outputs)
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=inputs, outputs=outputs)
return nodepy
def _build_module_node_group(self, module_name, op_type, node_cpps, input_to_node, output_to_node):
graph = self.trace.graph
inputs, outputs = [], []
for n in node_cpps:
for i in n.inputs():
name = i.debugName()
if not name in output_to_node and i in graph.inputs():
inputs.append(name)
elif output_to_node[name] not in node_cpps:
inputs.append(name)
for o in n.outputs():
name = o.debugName()
if not name in input_to_node and o in graph.outputs():
outputs.append(name)
elif input_to_node[name] not in node_cpps:
outputs.append(name)
return NodePyGroup(module_name, 'module', op_type, node_cpps, inputs, outputs)
def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
input_to_node, output_to_node, module_type):
"""
merge the adjacent nodes of the module. The difference between
_expand_module_node and _expand_non_prim_node is that _expand_non_prim_node
only merges the prim:: nodes into the aten:: node, while _expand_module_node
merges all adjacent nodes into the same NodePyGroup.
Parameters
----------
node : trace graph node
The module node to expand
node_name : str
specify the node_name for NodePyGroup
unique_name : str
unique_name for the NodePyGroup
op_type : str
specify the op_type for the NodePyGroup
nodes : list of trace graph node
All the trace graph nodes within the same scope as the module node
input_to_node : dict
key: input name, value: a node that uses this input
output_to_node : dict
key: output name, value: a node that generates this output
module_type : str
can be 'module' or 'func'
Returns
-------
node
the expanded module node
"""
_logger.debug("expand module node, node name: %s", node_name)
self.global_count += 1
if not op_type:
op_type = node.kind()
node_group = [node]
inputs = list()
outputs = list()
node_queue = queue.Queue()
node_queue.put(node)
visited = {node}
while not node_queue.empty():
curr_node = node_queue.get()
for _input in curr_node.inputs():
input_name = _input.debugName()
if input_name in output_to_node and output_to_node[input_name] in nodes:
predecessor_node = output_to_node[input_name]
if predecessor_node not in visited:
node_group.append(predecessor_node)
node_queue.put(predecessor_node)
visited.add(predecessor_node)
else:
inputs.append(input_name)
for _output in curr_node.outputs():
output_name = _output.debugName()
if output_name in input_to_node and input_to_node[output_name] in nodes:
successor_node = input_to_node[output_name]
if successor_node not in visited:
node_group.append(successor_node)
node_queue.put(successor_node)
visited.add(successor_node)
else:
outputs.append(output_name)
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=inputs, outputs=outputs)
return nodepy
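# Illustration (not part of the patch): the BFS above starts from one cpp node,
# only follows edges whose other endpoint is also in `nodes` (the same scope),
# and merges everything reachable into one NodePyGroup; edges that leave the
# scope become the group's inputs/outputs. E.g. the prim::GetAttr parameter
# nodes and the aten:: op traced for one module call collapse into one group
# whose only external input is the incoming activation.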
def _extract_shape_info(self, node):
"""
@@ -318,11 +410,12 @@ class TorchModuleGraph(TorchGraph):
parts1, parts2 = name1.split('.'), name2.split('.')
if len(parts1) >= len(parts2):
return False
for i in range(len(parts1)):
for i, _ in enumerate(parts1):
if parts2[i] != parts1[i]:
return False
return True
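# Illustration: is_parent('backbone', 'backbone.conv1') -> True,
# is_parent('backbone.conv1', 'backbone.conv2') -> False; with module_names
# sorted, a name is a leaf module exactly when the next name is not its child.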
module_names = sorted([x[0] for x in self.trace.named_modules() if x[0]])
module_names = sorted([x[0]
for x in self.trace.named_modules() if x[0]])
leaf_nodes = []
for i, name in enumerate(module_names):
if i + 1 >= len(module_names) or not is_parent(name, module_names[i + 1]):
@@ -354,7 +447,7 @@ class TorchModuleGraph(TorchGraph):
input_to_node = defaultdict(list)
output_to_node = dict()
for node in nodes_op:
name_to_node[node.name] = node
name_to_node[node.unique_name] = node
for _input in node.inputs:
input_to_node[_input].append(node)
for output in node.outputs:
@@ -385,9 +478,11 @@ class TorchModuleGraph(TorchGraph):
graph = self.trace.graph
_logger.debug(graph)
# build output mapping, from output debugName to its node
output_to_node = {x.debugName(): n for n in graph.nodes() for x in n.outputs()}
output_to_node = {x.debugName(): n for n in graph.nodes()
for x in n.outputs()}
# build input mapping, from input debugName to its node
input_to_node = {x.debugName(): n for n in graph.nodes() for x in n.inputs()}
input_to_node = {x.debugName(): n for n in graph.nodes()
for x in n.inputs()}
# build module mapping, from module name to all nodes (as list) under this module scope
module_to_nodes = defaultdict(list)
# the mapping of function (non-module in forward) to nodes, key is scope name
@@ -403,7 +498,8 @@ class TorchModuleGraph(TorchGraph):
nodes_py.append(NodePyIO(node, 'input'))
self.leaf_modules = self._extract_leaf_modules()
module_to_type = {name: parse_traced_name(module._name) for name, module in self.trace.named_modules()}
module_to_type = {name: parse_traced_name(
module._name) for name, module in self.trace.named_modules()}
# associate module names with their trace graph nodes
for node in graph.nodes():
@@ -412,14 +508,24 @@ class TorchModuleGraph(TorchGraph):
module_to_nodes[module_name].append(node)
else:
func_to_nodes[node.scopeName()].append(node)
# build node group for module
for module_name, node_cpps in module_to_nodes.items():
node_group = self._build_module_node_group(
module_name, module_to_type[module_name], node_cpps, input_to_node, output_to_node
)
_logger.debug('node_group: %s', node_group)
use_count = 0
merged = set()
for node in node_cpps:
if node not in merged:
# modules that have the same scope name may have different locations in the
# graph. Furthermore, there are also lots of prim:: nodes in node_cpps,
# so we also need to call _expand_module_node.
unique_name = module_name
if use_count > 0:
unique_name = module_name + '.%d' % use_count
node_group = self._expand_module_node(
node, module_name, unique_name, module_to_type[module_name],
node_cpps, input_to_node, output_to_node, 'module')
nodes_py.nodes_op.append(node_group)
use_count += 1
merged.update(node_group.node_cpps)
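# Illustration (names assumed from the test below): if self.relu is called three
# times in forward, its three traced occurrences get the unique names 'relu',
# 'relu.1' and 'relu.2', while all of them share the module name 'relu'.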
# each scope_name may have multiple funcs; we split them and create a node for each of them
# build node group for torch.nn.functional
@@ -431,11 +537,13 @@ class TorchModuleGraph(TorchGraph):
non_prim_nodes.append(node)
# for each non prim node, expand it
for node in non_prim_nodes:
node_group = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
node_group = self._expand_non_prim_node(
node, nodes, input_to_node, output_to_node, 'func')
nodes_py.nodes_op.append(node_group)
# get shape info for view (aten::view) func
if node_group.op_type in ['aten::view', 'aten::flatten']:
node_group.auxiliary = self._extract_shape_info(node)
for node in graph.outputs(): # Create sink nodes for output ops
node_py = NodePyIO(node, 'output')
nodes_py.append(node_py)
@@ -444,14 +552,14 @@ class TorchModuleGraph(TorchGraph):
# build index
return self._build_index(self.nodes_py.nodes_op)
def find_predecessors(self, module_name):
def find_predecessors(self, unique_name):
"""
Find the predecessor nodes of the given node
Parameters
----------
module_name : str
The name of the node
unique_name : str
The unique name of the node
Returns
-------
@@ -459,22 +567,22 @@ class TorchModuleGraph(TorchGraph):
a list of the unique names of the given node's predecessors
"""
predecessors = []
for _input in self.name_to_node[module_name].inputs:
for _input in self.name_to_node[unique_name].inputs:
if not _input in self.output_to_node:
_logger.debug("cannot find node with %s as its output", _input)
else:
node_py = self.output_to_node[_input]
predecessors.append(node_py.name)
predecessors.append(node_py.unique_name)
return predecessors
def find_successors(self, module_name):
def find_successors(self, unique_name):
"""
Find successor nodes of the given node
Parameters
----------
module_name : str
The name of the node
unique_name : str
The unique name of the node
Returns
-------
@@ -482,9 +590,11 @@ class TorchModuleGraph(TorchGraph):
a list of the unique names of the given node's successors
"""
successors = []
for output in self.name_to_node[module_name].outputs:
assert output in self.input_to_node, "No node with input {}".format(output)
for output in self.name_to_node[unique_name].outputs:
if output not in self.input_to_node:
# may reach the output of the whole graph
continue
nodes_py = self.input_to_node[output]
for node_py in nodes_py:
successors.append(node_py.name)
successors.append(node_py.unique_name)
return successors
@@ -15,7 +15,7 @@ from google.protobuf import text_format
import unittest
from unittest import TestCase, main
from nni._graph_utils import build_module_graph, build_graph
from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph
class BackboneModel1(nn.Module):
def __init__(self):
@@ -154,5 +154,45 @@ class GraphUtilsTestCase(TestCase):
os.path.join(os.path.dirname(__file__), "expect", "test_graph_module3.expect")
)
@unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
def test_module_reuse(self):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.liner1 = nn.Linear(10, 10)
self.relu = nn.ReLU(inplace=True)
self.liner2 = nn.Linear(10, 20)
self.liner3 = nn.Linear(20, 10)
def forward(self, x):
x = self.liner1(x)
x = self.relu(x)
x = self.liner2(x)
x = self.relu(x)
x = self.liner3(x)
x = self.relu(x)
return x
data = torch.rand(10, 10)
net = MyModule()
traced = torch.jit.trace(net, data)
modulegraph = TorchModuleGraph(traced_model=traced)
# Traverse the TorchModuleGraph. Due to the reuse of the relu module,
# there will be three cpp_nodes corresponding to the same module.
# While traversing the graph, there should be only one
# successor of each cpp_node (including the cpp_nodes that correspond
# to the same relu module).
for name, nodeio in modulegraph.nodes_py.nodes_io.items():
if nodeio.input_or_output == 'input':
# Find the first node of the whole graph
start_nodes = modulegraph.input_to_node[name]
# There is only a single top-down path
assert len(start_nodes) == 1
node = start_nodes[0].unique_name
while modulegraph.find_successors(node):
nodes = modulegraph.find_successors(node)
assert len(nodes) == 1
node = nodes[0]
if __name__ == '__main__':
main()