Unverified Commit f8633ac9 authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

Support the List/Tuple Construct/Unpack operation for TorchModuleGraph (#2609)

parent 66f2777f
......@@ -11,6 +11,10 @@ from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, G
# TorchScript trace node "kind" strings that this module special-cases
# when grouping trace nodes into NodePyGroups.
CLASSTYPE_KIND = 'ClassType'
GETATTR_KIND = 'prim::GetAttr'
CAT_KIND = 'aten::cat'
# List/Tuple Construct/Unpack nodes: these build or destructure tensor
# containers and need dedicated handling (see unpack_manually).
LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
LIST_UNPACK_KIND = 'prim::ListUnpack'
TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct'
TUPLE_UNPACK_KIND = 'prim::TupleUnpack'
# module-level logger, named after this module
_logger = logging.getLogger(__name__)
......@@ -177,7 +181,7 @@ class NodePyGroup(NodePy):
represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
"""
def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None):
def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None, key_node=None):
"""
Parameters:
-----------
......@@ -199,6 +203,8 @@ class NodePyGroup(NodePy):
All the inputs of this node, each element is debugName of one input
outputs: list of str
All the outputs of this node, each element is debugName of one output
key_node: torch._C.Node
The key node of this NodePyGroup.
"""
super(NodePyGroup, self).__init__(name, [])
self.node_cpps = node_cpps
......@@ -211,6 +217,8 @@ class NodePyGroup(NodePy):
self.add_nodes(node_cpps)
self.inputs = inputs
self.outputs = outputs
# The core node in this NodePyGroup
self.key_node = key_node
def add_nodes(self, node_cpps):
for node_cpp in node_cpps:
......@@ -239,7 +247,7 @@ class TorchModuleGraph(TorchGraph):
self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
self._extract_auxiliary_info()
def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
module_type):
"""
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
......@@ -284,7 +292,7 @@ class TorchModuleGraph(TorchGraph):
input_name = _input.debugName()
if input_name in output_to_node and output_to_node[input_name] in nodes:
predecessor_node = output_to_node[input_name]
if predecessor_node.kind().startswith('prim::'):
if not self._is_key_func(predecessor_node):
node_group.append(predecessor_node)
node_queue.put(predecessor_node)
else:
......@@ -294,7 +302,7 @@ class TorchModuleGraph(TorchGraph):
for output in node.outputs():
outputs.append(output.debugName())
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=inputs, outputs=outputs)
node_group, inputs=inputs, outputs=outputs, key_node=node)
return nodepy
def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
......@@ -510,6 +518,65 @@ class TorchModuleGraph(TorchGraph):
output_to_node[output] = node
return name_to_node, input_to_node, output_to_node
def _is_key_func(self, node_cpp):
    """
    Return True if the given cpp trace node is a "key function" node,
    i.e. one that must NOT be merged into an adjacent node when the
    graph is grouped into NodePyGroups.
    """
    node_kind = node_cpp.kind()
    # every 'aten::'-prefixed node is a real function call, hence a key node
    if node_kind.startswith('aten::'):
        return True
    # List/Tuple Unpack nodes have to stay separate: merging the
    # Construct/Unpack pair into other nodes may lead to a graph
    # construction error.
    return node_kind in (LIST_UNPACK_KIND, TUPLE_UNPACK_KIND)
def unpack_manually(self):
    """
    Unpack the tensor tuple or tensor list manually, and remove the
    ListUnpack/TupleUnpack node from the graph.

    Note: this function changes the graph structure in place (it rewires
    ``self.input_to_node`` and ``self.output_to_node``), so it is guarded
    by the ``unpacked`` flag and runs at most once per instance.
    """
    if hasattr(self, 'unpacked'):
        # if already unpacked the tuple/list manually; the rewiring below
        # is not idempotent, so never run it a second time
        return
    for node in self.nodes_py.nodes_op:
        if node.op_type in [TUPLE_UNPACK_KIND, LIST_UNPACK_KIND]:
            # key_node is the core cpp node of this NodePyGroup,
            # here the Tuple/ListUnpack trace node itself
            unpack_cpp = node.key_node
            # producer of the unpack node's (single) input tensor
            last_cpp = list(unpack_cpp.inputs())[0].node()
            if last_cpp.kind() in [TUPLE_CONSTRUCT_KIND, LIST_CONSTRUCT_KIND]:
                # we need to check if the tensor tuple or tensor list is produced
                # by a list/tuple construct node. If so, we can unpack the tuple
                # or list manually.
                _logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
                _logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
                # the construct's inputs must map 1:1 onto the unpack's outputs
                assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
                for _input, _output in zip(last_cpp.inputs(), unpack_cpp.outputs()):
                    _debug_input = _input.debugName()
                    _debug_output = _output.debugName()
                    if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
                        # input_to_node[_debug_input] is a list of NodePyGroup, because
                        # one tensor can be used as input for multiple nodes at the same time.
                        # note that, in this case, the construct cpp node and unpack cpp node
                        # will be merged into the same NodePyGroup, so we remove the `node` from
                        # input_to_node[_debug_input] and directly connect this tensor to the
                        # input_to_node[_debug_output]
                        self.input_to_node[_debug_input].remove(node)
                        # add the following nodes of _output into the input_to_node[_debug_input]
                        self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
                    if _debug_input in self.output_to_node and _debug_output in self.output_to_node:
                        # output_to_node[_debug_output] is a NodePyGroup, because one output
                        # tensor only can be generated by one node.
                        self.output_to_node[_debug_output] = self.output_to_node[_debug_input]
    # mark the graph as unpacked so repeated calls become no-ops
    self.unpacked = True
def _build_graph(self):
"""
Build graph using our defined format from jit trace.
......@@ -585,13 +652,14 @@ class TorchModuleGraph(TorchGraph):
# build node group for torch.nn.functional
for _, nodes in func_to_nodes.items():
# extract non prim:: nodes
non_prim_nodes = list()
key_func_nodes = list()
for node in nodes:
if not node.kind().startswith('prim::'):
non_prim_nodes.append(node)
if self._is_key_func(node):
# find the key function nodes
key_func_nodes.append(node)
# for each non prim node, expand it
for node in non_prim_nodes:
node_group = self._expand_non_prim_node(
for node in key_func_nodes:
node_group = self._expand_key_func_node(
node, nodes, input_to_node, output_to_node, 'func')
nodes_py.nodes_op.append(node_group)
# get shape infor for view (aten::view) func
......
......@@ -86,6 +86,9 @@ class ChannelDependency(Dependency):
Build the channel dependency for the conv layers
in the model.
"""
# unpack the tuple/list manually before analyze the
# channel dependency
self.graph.unpack_manually()
for node in self.graph.nodes_py.nodes_op:
parent_layers = []
# find the node that contains aten::add
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment