Unverified Commit 06e438b7 authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

support the scenario that there are duplicate tensors in a same tuple (#3340)

parent 7d6b8b3b
...@@ -285,8 +285,8 @@ class TorchModuleGraph(TorchGraph): ...@@ -285,8 +285,8 @@ class TorchModuleGraph(TorchGraph):
self.global_count += 1 self.global_count += 1
op_type = node.kind() op_type = node.kind()
node_group = [node] node_group = [node]
inputs = set() inputs = []
outputs = set() outputs = []
node_queue = queue.Queue() node_queue = queue.Queue()
node_queue.put(node) node_queue.put(node)
while not node_queue.empty(): while not node_queue.empty():
...@@ -303,17 +303,17 @@ class TorchModuleGraph(TorchGraph): ...@@ -303,17 +303,17 @@ class TorchModuleGraph(TorchGraph):
node_group.append(predecessor_node) node_group.append(predecessor_node)
node_queue.put(predecessor_node) node_queue.put(predecessor_node)
else: else:
inputs.add(input_name) inputs.append(input_name)
else: else:
inputs.add(input_name) inputs.append(input_name)
else: else:
inputs.add(input_name) inputs.append(input_name)
for output in node.outputs(): for output in node.outputs():
if output.node().kind() == CONSTANT_KIND: if output.node().kind() == CONSTANT_KIND:
continue continue
outputs.add(output.debugName()) outputs.append(output.debugName())
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type, nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=list(inputs), outputs=list(outputs), key_node=node) node_group, inputs=inputs, outputs=outputs, key_node=node)
return nodepy return nodepy
def _expand_module_node(self, node, node_name, unique_name, op_type, nodes, def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
...@@ -353,8 +353,8 @@ class TorchModuleGraph(TorchGraph): ...@@ -353,8 +353,8 @@ class TorchModuleGraph(TorchGraph):
if not op_type: if not op_type:
op_type = node.kind() op_type = node.kind()
node_group = [node] node_group = [node]
inputs = set() inputs = []
outputs = set() outputs = []
node_queue = queue.Queue() node_queue = queue.Queue()
node_queue.put(node) node_queue.put(node)
visited = {node} visited = {node}
...@@ -372,9 +372,9 @@ class TorchModuleGraph(TorchGraph): ...@@ -372,9 +372,9 @@ class TorchModuleGraph(TorchGraph):
node_queue.put(predecessor_node) node_queue.put(predecessor_node)
visited.add(predecessor_node) visited.add(predecessor_node)
else: else:
inputs.add(input_name) inputs.append(input_name)
else: else:
inputs.add(input_name) inputs.append(input_name)
for _output in curr_node.outputs(): for _output in curr_node.outputs():
if _output.node().kind() == CONSTANT_KIND: if _output.node().kind() == CONSTANT_KIND:
continue continue
...@@ -387,9 +387,9 @@ class TorchModuleGraph(TorchGraph): ...@@ -387,9 +387,9 @@ class TorchModuleGraph(TorchGraph):
node_queue.put(successor_node) node_queue.put(successor_node)
visited.add(successor_node) visited.add(successor_node)
else: else:
outputs.add(output_name) outputs.append(output_name)
else: else:
outputs.add(output_name) outputs.append(output_name)
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type, nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=list(inputs), outputs=list(outputs)) node_group, inputs=list(inputs), outputs=list(outputs))
...@@ -562,9 +562,12 @@ class TorchModuleGraph(TorchGraph): ...@@ -562,9 +562,12 @@ class TorchModuleGraph(TorchGraph):
for node in nodes_op: for node in nodes_op:
name_to_node[node.unique_name] = node name_to_node[node.unique_name] = node
for _input in node.inputs: for _input in node.inputs:
# inputs may have duplicate tensors
if node not in input_to_node[_input]:
input_to_node[_input].append(node) input_to_node[_input].append(node)
for output in node.outputs: for output in node.outputs:
assert not output in output_to_node, \ if output in output_to_node:
assert output_to_node[output] == node, \
"One output cannot be generated by multiple nodes %s" % output "One output cannot be generated by multiple nodes %s" % output
output_to_node[output] = node output_to_node[output] = node
return name_to_node, input_to_node, output_to_node return name_to_node, input_to_node, output_to_node
...@@ -619,8 +622,6 @@ class TorchModuleGraph(TorchGraph): ...@@ -619,8 +622,6 @@ class TorchModuleGraph(TorchGraph):
assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg
for _debug_input, _debug_output in zip(node.inputs, node.outputs): for _debug_input, _debug_output in zip(node.inputs, node.outputs):
# _debug_input = _input.debugName()
# _debug_output = _output.debugName()
if _debug_input in self.input_to_node and _debug_output in self.input_to_node: if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
# input_to_node[_debug_input] is a list of NodePyGroup, because # input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time. # one tensor can be used as input for multiple nodes at the same time.
...@@ -629,6 +630,7 @@ class TorchModuleGraph(TorchGraph): ...@@ -629,6 +630,7 @@ class TorchModuleGraph(TorchGraph):
# will be merged into the same NodePyGroup, so we remove the `node` from # will be merged into the same NodePyGroup, so we remove the `node` from
# input_to_node[_debug_input] and directly connect this tensor to the # input_to_node[_debug_input] and directly connect this tensor to the
# input_to_node[_debug_output] # input_to_node[_debug_output]
if node in self.input_to_node[_debug_input]:
self.input_to_node[_debug_input].remove(node) self.input_to_node[_debug_input].remove(node)
# add the following nodes of _output into the input_to_node[_debug_input] # add the following nodes of _output into the input_to_node[_debug_input]
self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output]) self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment