Unverified Commit f8633ac9 authored by Ningxin Zheng, committed by GitHub

Support the List/Tuple Construct/Unpack operation for TorchModuleGraph (#2609)

parent 66f2777f
@@ -11,6 +11,10 @@ from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, G
 CLASSTYPE_KIND = 'ClassType'
 GETATTR_KIND = 'prim::GetAttr'
 CAT_KIND = 'aten::cat'
+LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
+LIST_UNPACK_KIND = 'prim::ListUnpack'
+TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct'
+TUPLE_UNPACK_KIND = 'prim::TupleUnpack'
 
 _logger = logging.getLogger(__name__)
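
These four kinds are the TorchScript glue nodes emitted when tensors are grouped into, or pulled out of, Python lists and tuples. A minimal sketch of where they show up in a trace (the toy model below is illustrative, not part of this PR):

```python
import torch

class ToyNet(torch.nn.Module):
    def forward(self, x, y):
        z = torch.cat([x, y], dim=1)  # the Python list becomes prim::ListConstruct feeding aten::cat
        return z, z.relu()            # the tuple return emits prim::TupleConstruct

traced = torch.jit.trace(ToyNet(), (torch.randn(1, 2), torch.randn(1, 2)))
for node in traced.graph.nodes():
    print(node.kind())
# expected kinds include prim::ListConstruct, aten::cat, aten::relu, prim::TupleConstruct
```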
@@ -177,7 +181,7 @@ class NodePyGroup(NodePy):
     represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
     """
-    def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None):
+    def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None, key_node=None):
         """
         Parameters:
         -----------
@@ -199,6 +203,8 @@ class NodePyGroup(NodePy):
             All the inputs of this node, each element is debugName of one input
         outputs: list of str
             All the outputs of this node, each element is debugName of one output
+        key_node: torch._C.Node
+            The key node of this NodePyGroup.
         """
         super(NodePyGroup, self).__init__(name, [])
         self.node_cpps = node_cpps
@@ -211,6 +217,8 @@ class NodePyGroup(NodePy):
         self.add_nodes(node_cpps)
         self.inputs = inputs
         self.outputs = outputs
+        # The core node in this NodePyGroup
+        self.key_node = key_node
 
     def add_nodes(self, node_cpps):
         for node_cpp in node_cpps:
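
The new key_node field gives each functional NodePyGroup a handle to the single trace node the group was expanded from; the other node_cpps are the prim:: glue merged around it, and module-type groups leave it as None. A hypothetical inspection snippet, assuming a TorchModuleGraph built as shown further down in this diff:

```python
# hypothetical snippet; `graph` is a TorchModuleGraph built from a traced model
for group in graph.nodes_py.nodes_op:
    if group.key_node is not None:
        print(group.unique_name, '->', group.key_node.kind())
```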
@@ -239,7 +247,7 @@ class TorchModuleGraph(TorchGraph):
         self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
         self._extract_auxiliary_info()
 
-    def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
-                              module_type):
+    def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
+                               module_type):
         """
         For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
@@ -284,7 +292,7 @@ class TorchModuleGraph(TorchGraph):
                 input_name = _input.debugName()
                 if input_name in output_to_node and output_to_node[input_name] in nodes:
                     predecessor_node = output_to_node[input_name]
-                    if predecessor_node.kind().startswith('prim::'):
+                    if not self._is_key_func(predecessor_node):
                         node_group.append(predecessor_node)
                         node_queue.put(predecessor_node)
                     else:
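
The rewritten condition delegates the merge decision to `_is_key_func`, added later in this diff: only predecessors that are not key function nodes get absorbed into the current group. Restated as a standalone sketch (helper name hypothetical):

```python
# a minimal restatement of the merge rule this hunk implements
def absorbed_into_group(kind: str) -> bool:
    if kind.startswith('aten::'):
        return False  # aten:: ops are key function nodes and anchor their own groups
    if kind in ('prim::ListUnpack', 'prim::TupleUnpack'):
        return False  # unpack nodes must survive so unpack_manually() can rewire them
    return True       # remaining prim:: glue (constants, constructs, ...) is merged in
```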
@@ -294,7 +302,7 @@ class TorchModuleGraph(TorchGraph):
         for output in node.outputs():
             outputs.append(output.debugName())
         nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
-                             node_group, inputs=inputs, outputs=outputs)
+                             node_group, inputs=inputs, outputs=outputs, key_node=node)
         return nodepy
 
     def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
@@ -510,6 +518,65 @@ class TorchModuleGraph(TorchGraph):
                 output_to_node[output] = node
         return name_to_node, input_to_node, output_to_node
 
+    def _is_key_func(self, node_cpp):
+        """
+        Judge whether a cpp node is a key function node.
+        If so, we should not merge this node into the
+        adjacent node.
+        """
+        if node_cpp.kind().startswith('aten::'):
+            # nodes whose kind starts with 'aten::' are key
+            # function nodes
+            return True
+        if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]:
+            # We cannot merge the List/Tuple
+            # Construct/Unpack func into other nodes, otherwise it
+            # may lead to a graph construction error.
+            return True
+        return False
+
+    def unpack_manually(self):
+        """
+        Unpack the tensor tuple or tensor list manually,
+        and remove the ListUnpack/TupleUnpack node from
+        the graph. Note: this function will change the
+        graph structure.
+        """
+        if hasattr(self, 'unpacked'):
+            # the tuple/list has already been unpacked manually
+            return
+        for node in self.nodes_py.nodes_op:
+            if node.op_type in [TUPLE_UNPACK_KIND, LIST_UNPACK_KIND]:
+                unpack_cpp = node.key_node
+                last_cpp = list(unpack_cpp.inputs())[0].node()
+                if last_cpp.kind() in [TUPLE_CONSTRUCT_KIND, LIST_CONSTRUCT_KIND]:
+                    # we need to check if the tensor tuple or tensor list is produced
+                    # by a list/tuple construct node. If so, we can unpack the tuple
+                    # or list manually.
+                    _logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
+                    _logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
+                    assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
+                    for _input, _output in zip(last_cpp.inputs(), unpack_cpp.outputs()):
+                        _debug_input = _input.debugName()
+                        _debug_output = _output.debugName()
+                        if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
+                            # input_to_node[_debug_input] is a list of NodePyGroup, because
+                            # one tensor can be used as input for multiple nodes at the same time.
+                            # Note that, in this case, the construct cpp node and the unpack cpp node
+                            # are merged into the same NodePyGroup, so we remove `node` from
+                            # input_to_node[_debug_input] and directly connect this tensor to
+                            # input_to_node[_debug_output]
+                            self.input_to_node[_debug_input].remove(node)
+                            # add the successor nodes of _output into input_to_node[_debug_input]
+                            self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
+                        if _debug_input in self.output_to_node and _debug_output in self.output_to_node:
+                            # output_to_node[_debug_output] is a single NodePyGroup, because one
+                            # output tensor can only be generated by one node.
+                            self.output_to_node[_debug_output] = self.output_to_node[_debug_input]
+        self.unpacked = True
+
     def _build_graph(self):
         """
         Build graph using our defined format from jit trace.
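
A common trigger for this path is a submodule whose forward returns a tuple: after inlining, the callee's prim::TupleConstruct is immediately consumed by a prim::TupleUnpack at the call site, and unpack_manually() short-circuits the pair. An end-to-end sketch, assuming TorchModuleGraph is importable from nni._graph_utils (adjust the path to your checkout):

```python
import torch
from nni._graph_utils import TorchModuleGraph  # import path assumed

class Branch(torch.nn.Module):
    def __init__(self):
        super(Branch, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, 3, padding=1)
        self.conv2 = torch.nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):
        # tuple return -> prim::TupleConstruct in the inlined trace
        return self.conv1(x), self.conv2(x)

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.branch = Branch()

    def forward(self, x):
        a, b = self.branch(x)  # -> prim::TupleUnpack at the call site
        return a + b

graph = TorchModuleGraph(Net(), torch.randn(1, 3, 16, 16))
graph.unpack_manually()  # reconnects input_to_node/output_to_node around the pair
```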
@@ -585,13 +652,14 @@ class TorchModuleGraph(TorchGraph):
         # build node group for torch.nn.functional
         for _, nodes in func_to_nodes.items():
             # extract non prim:: nodes
-            non_prim_nodes = list()
+            key_func_nodes = list()
             for node in nodes:
-                if not node.kind().startswith('prim::'):
-                    non_prim_nodes.append(node)
+                if self._is_key_func(node):
+                    # find the key function nodes
+                    key_func_nodes.append(node)
             # for each non prim node, expand it
-            for node in non_prim_nodes:
-                node_group = self._expand_non_prim_node(
+            for node in key_func_nodes:
+                node_group = self._expand_key_func_node(
                     node, nodes, input_to_node, output_to_node, 'func')
                 nodes_py.nodes_op.append(node_group)
         # get shape info for view (aten::view) func
@@ -86,6 +86,9 @@ class ChannelDependency(Dependency):
         Build the channel dependency for the conv layers
         in the model.
         """
+        # unpack the tuple/list manually before analyzing the
+        # channel dependency
+        self.graph.unpack_manually()
         for node in self.graph.nodes_py.nodes_op:
             parent_layers = []
             # find the node that contains aten::add
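
With the manual unpack performed up front, channel-dependency analysis now works on models whose traces route tensors through tuple/list indirection (e.g. submodules returning tuples, or torch.chunk/torch.split). A usage sketch, assuming the import path this class had at the time of this PR:

```python
import torch
from torchvision.models import resnet18
from nni.compression.torch.utils.shape_dependency import ChannelDependency  # path assumed

model = resnet18()
dependency = ChannelDependency(model, torch.randn(1, 3, 224, 224))
dependency.export('channel_dependency.csv')  # one set of channel-dependent layers per line
```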