Unverified Commit 5946b4a4 authored by ZHANG Zhi's avatar ZHANG Zhi Committed by GitHub
Browse files

Add to permit (#3357)

parent 06e438b7
...@@ -315,6 +315,11 @@ class GraphConverter: ...@@ -315,6 +315,11 @@ class GraphConverter:
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, self.global_seq), node.kind()) new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, self.global_seq), node.kind())
node_index[node] = new_node node_index[node] = new_node
self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap) self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'prim::TupleConstruct':
self.global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.TupleConstruct, self.global_seq), node.kind())
node_index[node] = new_node
self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append': elif node.kind() == 'aten::append':
self.global_seq += 1 self.global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind()) aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
......
...@@ -10,6 +10,7 @@ class OpTypeName(str, Enum): ...@@ -10,6 +10,7 @@ class OpTypeName(str, Enum):
Attr = 'Attr' Attr = 'Attr'
Constant = 'Constant' Constant = 'Constant'
ListConstruct = 'ListConstruct' ListConstruct = 'ListConstruct'
TupleConstruct = 'TupleConstruct'
LayerChoice = 'LayerChoice' LayerChoice = 'LayerChoice'
InputChoice = 'InputChoice' InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice' ValueChoice = 'ValueChoice'
......
...@@ -121,6 +121,8 @@ class PyTorchOperation(Operation): ...@@ -121,6 +121,8 @@ class PyTorchOperation(Operation):
return f'{output} = {value}' return f'{output} = {value}'
elif self.type == 'prim::ListConstruct': elif self.type == 'prim::ListConstruct':
return f'{output} = [{", ".join(inputs)}]' return f'{output} = [{", ".join(inputs)}]'
elif self.type == 'prim::TupleConstruct':
return f'{output} = ({", ".join(inputs)})'
elif self.type == 'prim::GetAttr': elif self.type == 'prim::GetAttr':
return f"{output} = {self.parameters['input']}.{self.parameters['name']}" return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
elif self.type == 'aten::mean': elif self.type == 'aten::mean':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment