Unverified Commit 5d2a59fd authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

Successive unpack (#2768)

parent e7fccfb4
...@@ -530,8 +530,15 @@ class TorchModuleGraph(TorchGraph): ...@@ -530,8 +530,15 @@ class TorchModuleGraph(TorchGraph):
return True return True
if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]: if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]:
# We cannot merge the List/Tuple # We cannot merge the List/Tuple
# Construct/Unpack func into other nodes, else it # Unpack func into other nodes, else it
# may lead to a graph construction error. # may lead to a graph construction error.
# The reason why we do not also treat the Construct
# node as a key node is that a `cat` operation node needs
# the last (previous) visited node to infer the mask. If
# we took the Construct node as an important node, the
# predecessor of the `cat` node would always be a Construct
# node, which means we could not infer the mask for the
# `cat` operation.
return True return True
return False return False
...@@ -556,9 +563,13 @@ class TorchModuleGraph(TorchGraph): ...@@ -556,9 +563,13 @@ class TorchModuleGraph(TorchGraph):
_logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp)) _logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
_logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp)) _logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs())) assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
for _input, _output in zip(last_cpp.inputs(), unpack_cpp.outputs()): errmsg = '%s Input number: %d is inconsistent with the output number %d' % (unpack_cpp, \
_debug_input = _input.debugName() len(node.inputs), len(list(last_cpp.inputs())))
_debug_output = _output.debugName()
assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg
for _debug_input, _debug_output in zip(node.inputs, node.outputs):
# _debug_input = _input.debugName()
# _debug_output = _output.debugName()
if _debug_input in self.input_to_node and _debug_output in self.input_to_node: if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
# input_to_node[_debug_input] is a list of NodePyGroup, because # input_to_node[_debug_input] is a list of NodePyGroup, because
# one tensor can be used as input for multiple nodes at the same time. # one tensor can be used as input for multiple nodes at the same time.
...@@ -570,10 +581,13 @@ class TorchModuleGraph(TorchGraph): ...@@ -570,10 +581,13 @@ class TorchModuleGraph(TorchGraph):
self.input_to_node[_debug_input].remove(node) self.input_to_node[_debug_input].remove(node)
# add the following nodes of _output into the input_to_node[_debug_input] # add the following nodes of _output into the input_to_node[_debug_input]
self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output]) self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
if _debug_input in self.output_to_node and _debug_output in self.output_to_node: # just remove the _debug_output from the graph index. So that we can also skip
# output_to_node[_debug_output] is a NodePyGroup, because one output # the construct and tuple
# tensor only can be generated by one node. if _debug_output in self.input_to_node:
self.output_to_node[_debug_output] = self.output_to_node[_debug_input] for following_node in self.input_to_node[_debug_output]:
_tmp_index = following_node.inputs.index(_debug_output)
following_node.inputs[_tmp_index] = _debug_input
self.unpacked = True self.unpacked = True
......
...@@ -15,7 +15,7 @@ from google.protobuf import text_format ...@@ -15,7 +15,7 @@ from google.protobuf import text_format
import unittest import unittest
from unittest import TestCase, main from unittest import TestCase, main
from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph, TUPLE_UNPACK_KIND
class BackboneModel1(nn.Module): class BackboneModel1(nn.Module):
def __init__(self): def __init__(self):
...@@ -194,5 +194,101 @@ class GraphUtilsTestCase(TestCase): ...@@ -194,5 +194,101 @@ class GraphUtilsTestCase(TestCase):
assert len(nodes) == 1 assert len(nodes) == 1
node = nodes[0] node = nodes[0]
@unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
def test_module_unpack(self):
    """
    Exercise the tuple/list unpack handling of TorchModuleGraph.

    The models below come from issue 2756
    (https://github.com/microsoft/nni/issues/2756): ``MyModule``
    contains two successive tuple-unpack operations between the ``B``
    and ``C`` submodules. After calling ``unpack_manually`` no
    operation node in the graph should take its input from a
    TupleUnpack node.
    """
    class CBR(nn.Module):
        # Conv -> BatchNorm -> ReLU building block.
        def __init__(self, i, o):
            super(CBR, self).__init__()
            self.conv1 = nn.Conv2d(i, o, kernel_size=1)
            self.bn1 = nn.BatchNorm2d(o)
            self.act1 = nn.ReLU()

        def forward(self, x):
            return self.act1(self.bn1(self.conv1(x)))

    class A(nn.Module):
        # Returns a 2-tuple of feature maps.
        def __init__(self):
            super(A, self).__init__()
            self.conv1 = CBR(3, 6)
            self.conv2 = CBR(6, 8)
            self.conv3 = CBR(6, 12)

        def forward(self, x):
            stem = self.conv1(x)
            return (self.conv2(stem), self.conv3(stem))

    class B1(nn.Module):
        # Returns a 3-tuple of intermediate feature maps.
        def __init__(self):
            super(B1, self).__init__()
            self.conv1 = CBR(12, 32)
            self.conv2 = CBR(32, 32)
            self.conv3 = CBR(32, 32)

        def forward(self, x):
            out1 = self.conv1(x)
            out2 = self.conv2(out1)
            out3 = self.conv3(out2)
            return (out1, out2, out3)

    class B(nn.Module):
        # Indexing the incoming tuple (x[-1]) forces a TupleUnpack.
        def __init__(self):
            super(B, self).__init__()
            self.b = B1()

        def forward(self, x):
            return self.b(x[-1])

    class C(nn.Module):
        # Consumes all five tensors of the concatenated tuple,
        # triggering another TupleUnpack.
        def __init__(self):
            super(C, self).__init__()
            self.conv1 = CBR(8, 32)
            self.conv2 = CBR(12, 32)
            self.conv3 = CBR(32, 32)
            self.conv4 = CBR(32, 32)
            self.conv5 = CBR(32, 32)

        def forward(self, x):
            return (self.conv1(x[0]), self.conv2(x[1]), self.conv3(x[2]),
                    self.conv4(x[3]), self.conv5(x[4]))

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.a = A()
            self.b = B()
            self.c = C()

        def forward(self, x):
            x_a = self.a(x)
            x_b = self.b(x_a)
            # Tuple concatenation: (x2, x3) + (x1', x2', x3') -> 5-tuple.
            return self.c(x_a + x_b)

    dummy_input = torch.rand(1, 3, 28, 28)
    graph = TorchModuleGraph(MyModule(), dummy_input)
    graph.unpack_manually()
    # After manual unpacking, no op node's input tensor may be
    # produced by a TupleUnpack node — all of them were removed.
    for node in graph.nodes_py.nodes_op:
        for _input in node.inputs:
            if _input in graph.output_to_node:
                producer = graph.output_to_node[_input]
                assert producer.op_type != TUPLE_UNPACK_KIND
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment