Unverified Commit 52c2d4d3 authored by Jiahang Xu, committed by GitHub

[Retiarii] Fix bug in codegen for identity node (#4263)

parent 3cc2df58
@@ -381,17 +381,8 @@ class GraphConverter:
             # step #1: generate graph ir for this method
             method_ir_graph = Graph(model=ir_model, graph_id=-100, name='temp_graph', _internal=True)
-            method_node_index = self.handle_graph_nodes(script_module, script_method.graph, module,
-                                                        module_name, ir_model, method_ir_graph, shared_module_index)
-            for _output in script_method.graph.outputs():
-                method_ir_graph._add_output(_convert_name(_output.debugName()))
-                predecessor_node_outputs = [o for o in _output.node().outputs()]
-                if len(predecessor_node_outputs) == 1:
-                    src_node_idx = None
-                else:
-                    src_node_idx = predecessor_node_outputs.index(_output)
-                method_ir_graph.add_edge(head=(method_node_index[_output.node()], src_node_idx),
-                                         tail=(method_ir_graph.output_node, None))
+            self.handle_graph_nodes(script_module, script_method.graph, module,
+                                    module_name, ir_model, method_ir_graph, shared_module_index)
             self.refine_graph(method_ir_graph)
             # step #2: merge this graph to its module graph
@@ -491,18 +482,24 @@ class GraphConverter:
         for node in sm_graph.nodes():
             handle_single_node(node)
 
-        if node_index == {}:
-            # here is an example that the ir_graph is empty
+        if node_index != {}:
+            for _output in sm_graph.outputs():
+                ir_graph._add_output(_convert_name(_output.debugName()))
+                predecessor_node_outputs = [o for o in _output.node().outputs()]
+                if len(predecessor_node_outputs) == 1:
+                    src_node_idx = None
+                else:
+                    src_node_idx = predecessor_node_outputs.index(_output)
+                ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
+                                  tail=(ir_graph.output_node, None))
+        else:
+            # here is an example where the ir_graph and node_index are empty:
             # graph(%self : __torch__.torchmodels.googlenet.GoogLeNet,
             #       %x.1 : Tensor): return (%x.1)
-            # add a noop_identity node to handle this situation
-            self.global_seq += 1
-            ni_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
-            ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ni_node, None))
-            ir_graph.add_edge(head=(ni_node, None), tail=(ir_graph.output_node, None))
-            for _output in sm_graph.outputs():
-                node_index[_output.node()] = ni_node
-        return node_index
+            # add an edge from head to tail to handle this situation
+            ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ir_graph.output_node, None))
 
     def merge_aten_slices(self, ir_graph):
         """
@@ -625,20 +622,8 @@ class GraphConverter:
         ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)
 
         # handle graph nodes
-        node_index = self.handle_graph_nodes(script_module, sm_graph, module,
-                                             module_name, ir_model, ir_graph)
-        # handle graph outputs
-        for _output in sm_graph.outputs():
-            ir_graph._add_output(_convert_name(_output.debugName()))
-            predecessor_node_outputs = [o for o in _output.node().outputs()]
-            if len(predecessor_node_outputs) == 1:
-                src_node_idx = None
-            else:
-                src_node_idx = predecessor_node_outputs.index(_output)
-            ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
-                              tail=(ir_graph.output_node, None))
+        self.handle_graph_nodes(script_module, sm_graph, module,
+                                module_name, ir_model, ir_graph)
         self.refine_graph(ir_graph)
         ir_graph._register()
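Both call sites above previously duplicated the output-wiring rule that now lives only inside handle_graph_nodes: if the TorchScript node producing a graph output has a single output value, the edge head index is None; otherwise it is the position of that value among the node's outputs. A small sketch of that rule as a standalone helper (hypothetical name, for illustration only; not part of the codebase):

def output_slot_index(ts_value):
    # ts_value: a torch._C.Value that feeds a graph output
    outputs = list(ts_value.node().outputs())
    # single-output producer -> None; multi-output producer -> positional index
    return None if len(outputs) == 1 else outputs.index(ts_value)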
@@ -690,7 +675,7 @@ class GraphConverterWithShape(GraphConverter):
     Known issues
     ------------
     1. `InputChoice` and `ValueChoice` not supported yet.
-    2. Currently random inputs are feeded while tracing layerchoice.
+    2. Currently random inputs are fed while tracing layerchoice.
        If forward path of candidates depends on input data, then wrong path will be traced.
        This will result in incomplete shape info.
     """
@@ -90,5 +90,35 @@ class TestModels(unittest.TestCase, ConvertMixin):
         x = torch.rand((1, 16), dtype=torch.float)
         self.run_test(model, ([x], ))
 
+    def test_identity_node(self):
+        class Net(nn.Module):
+            def forward(self, x):
+                return x
+
+        model = Net()
+        x = torch.rand((1, 64, 224, 224), dtype=torch.float)
+        self.run_test(model, (x, ))
+
+    def test_nn_sequential_inherit(self):
+        class ConvBNReLU(nn.Sequential):
+            def __init__(self):
+                super().__init__(
+                    nn.Conv2d(3, 3, 1, 1, bias=False),
+                    nn.BatchNorm2d(3),
+                    nn.ReLU(inplace=False)
+                )
+
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv_bn_relu = ConvBNReLU()
+
+            def forward(self, x):
+                return self.conv_bn_relu(x)
+
+        model = Net()
+        x = torch.rand((1, 3, 224, 224), dtype=torch.float)
+        self.run_test(model, (x, ))
+
 
 class TestModelsWithShape(TestModels, ConvertWithShapeMixin):
     pass