Unverified Commit 5f0a7c9a authored by Jiahang Xu, committed by GitHub

Fix bug in graph converter from Graph_gen (#4092)

parent bf18854a
@@ -491,6 +491,17 @@ class GraphConverter:
         for node in sm_graph.nodes():
             handle_single_node(node)
 
+        if node_index == {}:
+            # here is an example where the ir_graph is empty:
+            # graph(%self : __torch__.torchmodels.googlenet.GoogLeNet,
+            #       %x.1 : Tensor): return (%x.1)
+            # add a noop_identity node to handle this situation
+            self.global_seq += 1
+            ni_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
+            ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ni_node, None))
+            ir_graph.add_edge(head=(ni_node, None), tail=(ir_graph.output_node, None))
+            for _output in sm_graph.outputs():
+                node_index[_output.node()] = ni_node
         return node_index
 
     def merge_aten_slices(self, ir_graph):
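
The comment in the hunk above describes the failure mode: a submodule whose TorchScript graph contains no operator nodes, so `handle_single_node` never fires and `node_index` stays empty. Below is a minimal sketch of that case; the `PassThrough` module is a hypothetical stand-in, not part of the commit.

```python
import torch
import torch.nn as nn

class PassThrough(nn.Module):
    """Hypothetical module whose forward is a pure pass-through."""
    def forward(self, x):
        return x

sm = torch.jit.script(PassThrough())
print(sm.graph)
# graph(%self : __torch__.PassThrough,
#       %x.1 : Tensor):
#   return (%x.1)

# No operator nodes at all: the loop over sm_graph.nodes() does nothing
# and node_index stays {}. Without the new noop_identity node, the
# resulting IR graph would leave its input and output disconnected.
print(list(sm.graph.nodes()))  # []
```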
@@ -575,9 +586,7 @@ class GraphConverter:
         # also has LayerChoice or InputChoice or ValueChoice
         original_type_name = script_module.original_name
         m_attrs = None
-        if original_type_name in MODULE_EXCEPT_LIST:
-            pass  # do nothing
-        elif original_type_name == OpTypeName.LayerChoice:
+        if original_type_name == OpTypeName.LayerChoice:
             graph = Graph(ir_model, -100, module_name, _internal=True)  # graph_id is not used now
             candidate_name_list = []
             for cand_name in module.names:
@@ -599,7 +608,9 @@ class GraphConverter:
             m_attrs = self._handle_valuechoice(module)
         elif original_type_name == OpTypeName.Placeholder:
             m_attrs = get_init_parameters_or_fail(module)
-        elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__:
+        elif module.__class__.__module__.startswith('torch.nn') and \
+                original_type_name in torch.nn.__dict__ and \
+                original_type_name not in MODULE_EXCEPT_LIST:
             # this is a basic module from pytorch, no need to parse its graph
             m_attrs = get_init_parameters_or_fail(module)
         elif getattr(module, '_stop_parsing', False):
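
Taken together, the two hunks above mean `MODULE_EXCEPT_LIST` no longer short-circuits conversion with a bare `pass`; it only disqualifies a module from the "basic torch.nn module" fast path. A condensed sketch of the new ordering follows; the `classify` helper is hypothetical, for illustration only.

```python
import torch
import torch.nn as nn

MODULE_EXCEPT_LIST = ['Sequential']  # as defined in the hunk below

def classify(module):
    """Hypothetical condensation of the converter's module dispatch."""
    original_type_name = type(module).__name__
    if module.__class__.__module__.startswith('torch.nn') and \
            original_type_name in torch.nn.__dict__ and \
            original_type_name not in MODULE_EXCEPT_LIST:
        # basic pytorch module: record init parameters, skip graph parsing
        return 'basic module'
    # everything else, including nn.Sequential, now has its graph converted
    # instead of being silently skipped by the old `pass` branch
    return 'convert graph'

print(classify(nn.Linear(2, 2)))   # 'basic module'
print(classify(nn.Sequential()))   # 'convert graph'
```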
@@ -3,6 +3,7 @@
 from enum import Enum
 
+# except for the special cases that cannot be treated as basic modules from pytorch
 MODULE_EXCEPT_LIST = ['Sequential']
@@ -41,7 +41,7 @@ if version_larger_equal(torch.__version__, '1.7.0'):
 Module = nn.Module
-Sequential = transparent_serialize(nn.Sequential)
+Sequential = nn.Sequential
 ModuleList = transparent_serialize(nn.ModuleList)
 Identity = basic_unit(nn.Identity)
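
The last hunk complements the converter change: the retiarii `Sequential` alias becomes the plain `nn.Sequential` class rather than a `transparent_serialize` wrapper, so the converter can descend into its children. A small illustration of the new alias, assuming only what the hunk shows:

```python
import torch.nn as nn

# After this commit the alias is the plain class (previously
# transparent_serialize(nn.Sequential)); children such as Conv2d and ReLU
# are still converted individually, while Sequential itself is handled by
# the MODULE_EXCEPT_LIST fallthrough shown above.
Sequential = nn.Sequential

block = Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
print(type(block) is nn.Sequential)  # True
```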