Unverified Commit 5f0a7c9a authored by Jiahang Xu's avatar Jiahang Xu Committed by GitHub
Browse files

Fix bug in graph converter from Graph_gen (#4092)

parent bf18854a
...@@ -491,6 +491,17 @@ class GraphConverter: ...@@ -491,6 +491,17 @@ class GraphConverter:
for node in sm_graph.nodes(): for node in sm_graph.nodes():
handle_single_node(node) handle_single_node(node)
if node_index == {}:
# here is an example that the ir_graph is empty
# graph(%self : __torch__.torchmodels.googlenet.GoogLeNet,
# %x.1 : Tensor): return (%x.1)
# add a noop_identity node to handle this situation
self.global_seq += 1
ni_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
ir_graph.add_edge(head=(ir_graph.input_node, 0), tail=(ni_node, None))
ir_graph.add_edge(head=(ni_node, None), tail=(ir_graph.output_node, None))
for _output in sm_graph.outputs():
node_index[_output.node()] = ni_node
return node_index return node_index
def merge_aten_slices(self, ir_graph): def merge_aten_slices(self, ir_graph):
...@@ -575,9 +586,7 @@ class GraphConverter: ...@@ -575,9 +586,7 @@ class GraphConverter:
# also has LayerChoice or InputChoice or ValueChoice # also has LayerChoice or InputChoice or ValueChoice
original_type_name = script_module.original_name original_type_name = script_module.original_name
m_attrs = None m_attrs = None
if original_type_name in MODULE_EXCEPT_LIST: if original_type_name == OpTypeName.LayerChoice:
pass # do nothing
elif original_type_name == OpTypeName.LayerChoice:
graph = Graph(ir_model, -100, module_name, _internal=True) # graph_id is not used now graph = Graph(ir_model, -100, module_name, _internal=True) # graph_id is not used now
candidate_name_list = [] candidate_name_list = []
for cand_name in module.names: for cand_name in module.names:
...@@ -599,7 +608,9 @@ class GraphConverter: ...@@ -599,7 +608,9 @@ class GraphConverter:
m_attrs = self._handle_valuechoice(module) m_attrs = self._handle_valuechoice(module)
elif original_type_name == OpTypeName.Placeholder: elif original_type_name == OpTypeName.Placeholder:
m_attrs = get_init_parameters_or_fail(module) m_attrs = get_init_parameters_or_fail(module)
elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__: elif module.__class__.__module__.startswith('torch.nn') and \
original_type_name in torch.nn.__dict__ and \
original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph # this is a basic module from pytorch, no need to parse its graph
m_attrs = get_init_parameters_or_fail(module) m_attrs = get_init_parameters_or_fail(module)
elif getattr(module, '_stop_parsing', False): elif getattr(module, '_stop_parsing', False):
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from enum import Enum from enum import Enum
# except the special case which can not treat as a basic module from pytorch
MODULE_EXCEPT_LIST = ['Sequential'] MODULE_EXCEPT_LIST = ['Sequential']
......
...@@ -41,7 +41,7 @@ if version_larger_equal(torch.__version__, '1.7.0'): ...@@ -41,7 +41,7 @@ if version_larger_equal(torch.__version__, '1.7.0'):
Module = nn.Module Module = nn.Module
Sequential = transparent_serialize(nn.Sequential) Sequential = nn.Sequential
ModuleList = transparent_serialize(nn.ModuleList) ModuleList = transparent_serialize(nn.ModuleList)
Identity = basic_unit(nn.Identity) Identity = basic_unit(nn.Identity)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment