Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d50b4665
Unverified
Commit
d50b4665
authored
Nov 29, 2021
by
Jiahang Xu
Committed by
GitHub
Nov 29, 2021
Browse files
Add python name as Node attribute of graph_gen (#4243)
parent
068775f3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
110 additions
and
16 deletions
+110
-16
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+33
-15
nni/retiarii/converter/utils.py
nni/retiarii/converter/utils.py
+9
-0
nni/retiarii/graph.py
nni/retiarii/graph.py
+40
-1
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+28
-0
No files found.
nni/retiarii/converter/graph_gen.py
View file @
d50b4665
...
@@ -14,7 +14,8 @@ from .op_types import MODULE_EXCEPT_LIST, OpTypeName
...
@@ -14,7 +14,8 @@ from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from
.utils
import
(
from
.utils
import
(
_convert_name
,
build_full_name
,
_without_shape_info
,
_convert_name
,
build_full_name
,
_without_shape_info
,
_extract_info_from_trace_node
,
get_full_name_by_scope_name
,
_extract_info_from_trace_node
,
get_full_name_by_scope_name
,
is_layerchoice_node
,
match_node
,
build_cand_name
is_layerchoice_node
,
match_node
,
build_cand_name
,
build_python_name
)
)
...
@@ -139,7 +140,7 @@ class GraphConverter:
...
@@ -139,7 +140,7 @@ class GraphConverter:
hidden_node
.
remove
()
hidden_node
.
remove
()
def
handle_graph_nodes
(
self
,
script_module
,
sm_graph
,
def
handle_graph_nodes
(
self
,
script_module
,
sm_graph
,
module
,
module_name
,
module
,
module_name
,
module_python_name
,
ir_model
,
ir_graph
,
ir_model
,
ir_graph
,
shared_module_index
=
None
):
shared_module_index
=
None
):
"""
"""
...
@@ -317,10 +318,12 @@ class GraphConverter:
...
@@ -317,10 +318,12 @@ class GraphConverter:
submodule_name
,
script_module
.
_modules
.
keys
())
submodule_name
,
script_module
.
_modules
.
keys
())
submodule_full_name
=
build_full_name
(
module_name
,
submodule_name
)
submodule_full_name
=
build_full_name
(
module_name
,
submodule_name
)
submodule_python_name
=
build_python_name
(
module_python_name
,
submodule_name
)
submodule_obj
=
getattr
(
module
,
submodule_name
)
submodule_obj
=
getattr
(
module
,
submodule_name
)
subgraph
,
sub_m_attrs
=
self
.
_convert_module
(
script_module
.
_modules
[
submodule_name
],
subgraph
,
sub_m_attrs
=
self
.
_convert_module
(
script_module
.
_modules
[
submodule_name
],
submodule_obj
,
submodule_obj
,
submodule_full_name
,
ir_model
)
submodule_full_name
,
submodule_python_name
,
ir_model
)
else
:
else
:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
...
@@ -347,12 +350,14 @@ class GraphConverter:
...
@@ -347,12 +350,14 @@ class GraphConverter:
assert
predecessor
.
hasAttribute
(
'name'
)
assert
predecessor
.
hasAttribute
(
'name'
)
module_name_space
.
append
(
predecessor
.
s
(
'name'
))
module_name_space
.
append
(
predecessor
.
s
(
'name'
))
submodule_full_name
=
build_full_name
(
module_name
,
list
(
reversed
(
module_name_space
)))
submodule_full_name
=
build_full_name
(
module_name
,
list
(
reversed
(
module_name_space
)))
submodule_python_name
=
build_python_name
(
module_python_name
,
list
(
reversed
(
module_name_space
)))
submodule_obj
=
module
submodule_obj
=
module
script_submodule
=
script_module
script_submodule
=
script_module
for
each_name
in
list
(
reversed
(
module_name_space
)):
for
each_name
in
list
(
reversed
(
module_name_space
)):
submodule_obj
=
getattr
(
submodule_obj
,
each_name
)
submodule_obj
=
getattr
(
submodule_obj
,
each_name
)
script_submodule
=
script_submodule
.
_modules
[
each_name
]
script_submodule
=
script_submodule
.
_modules
[
each_name
]
subgraph
,
sub_m_attrs
=
self
.
_convert_module
(
script_submodule
,
submodule_obj
,
submodule_full_name
,
ir_model
)
subgraph
,
sub_m_attrs
=
self
.
_convert_module
(
script_submodule
,
submodule_obj
,
submodule_full_name
,
submodule_python_name
,
ir_model
)
else
:
else
:
raise
RuntimeError
(
'Unsupported module case: {}'
.
format
(
submodule
.
inputsAt
(
0
).
type
().
str
()))
raise
RuntimeError
(
'Unsupported module case: {}'
.
format
(
submodule
.
inputsAt
(
0
).
type
().
str
()))
...
@@ -362,13 +367,16 @@ class GraphConverter:
...
@@ -362,13 +367,16 @@ class GraphConverter:
# example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
# example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
self
.
global_seq
+=
1
self
.
global_seq
+=
1
shared_node_name
=
build_full_name
(
submodule_full_name
,
''
,
self
.
global_seq
)
shared_node_name
=
build_full_name
(
submodule_full_name
,
''
,
self
.
global_seq
)
shared_node_python_name
=
build_python_name
(
submodule_python_name
,
self
.
global_seq
)
shared_type_operation
=
Operation
.
new
(
'shared'
,
{
'reference'
:
submodule_full_name
})
shared_type_operation
=
Operation
.
new
(
'shared'
,
{
'reference'
:
submodule_full_name
})
subcell
=
ir_graph
.
add_node
(
shared_node_name
,
shared_type_operation
)
subcell
=
ir_graph
.
add_node
(
shared_node_name
,
shared_type_operation
)
subcell
.
python_name
=
shared_node_python_name
else
:
else
:
# this module is processed for the first time, build cell for it
# this module is processed for the first time, build cell for it
if
subgraph
is
None
:
if
subgraph
is
None
:
# if we do not parse this module's graph, we create Node for this module
# if we do not parse this module's graph, we create Node for this module
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
submodule_type_str
,
sub_m_attrs
)
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
submodule_type_str
,
sub_m_attrs
)
subcell
.
python_name
=
submodule_python_name
if
isinstance
(
submodule_obj
,
Placeholder
):
if
isinstance
(
submodule_obj
,
Placeholder
):
subcell
.
update_label
(
submodule_obj
.
label
)
subcell
.
update_label
(
submodule_obj
.
label
)
elif
isinstance
(
submodule_obj
,
InputChoice
):
elif
isinstance
(
submodule_obj
,
InputChoice
):
...
@@ -377,6 +385,7 @@ class GraphConverter:
...
@@ -377,6 +385,7 @@ class GraphConverter:
# Graph already created, create Cell for it
# Graph already created, create Cell for it
new_cell
=
Cell
(
cell_name
=
submodule_full_name
,
parameters
=
sub_m_attrs
)
new_cell
=
Cell
(
cell_name
=
submodule_full_name
,
parameters
=
sub_m_attrs
)
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
new_cell
)
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
new_cell
)
subcell
.
python_name
=
submodule_python_name
shared_module_index
[
submodule_full_name
]
=
subcell
shared_module_index
[
submodule_full_name
]
=
subcell
node_index
[
node
]
=
subcell
node_index
[
node
]
=
subcell
# connect the cell into graph
# connect the cell into graph
...
@@ -391,7 +400,7 @@ class GraphConverter:
...
@@ -391,7 +400,7 @@ class GraphConverter:
# step #1: generate graph ir for this method
# step #1: generate graph ir for this method
method_ir_graph
=
Graph
(
model
=
ir_model
,
graph_id
=-
100
,
name
=
'temp_graph'
,
_internal
=
True
)
method_ir_graph
=
Graph
(
model
=
ir_model
,
graph_id
=-
100
,
name
=
'temp_graph'
,
_internal
=
True
)
self
.
handle_graph_nodes
(
script_module
,
script_method
.
graph
,
module
,
self
.
handle_graph_nodes
(
script_module
,
script_method
.
graph
,
module
,
module_name
,
ir_model
,
method_ir_graph
,
shared_module_index
)
module_name
,
module_python_name
,
ir_model
,
method_ir_graph
,
shared_module_index
)
self
.
refine_graph
(
method_ir_graph
)
self
.
refine_graph
(
method_ir_graph
)
# step #2: merge this graph to its module graph
# step #2: merge this graph to its module graph
...
@@ -439,6 +448,8 @@ class GraphConverter:
...
@@ -439,6 +448,8 @@ class GraphConverter:
self
.
global_seq
+=
1
self
.
global_seq
+=
1
func_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
func_name
,
self
.
global_seq
),
func_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
func_name
,
self
.
global_seq
),
'{}.{}'
.
format
(
func_type_str
,
func_name
))
'{}.{}'
.
format
(
func_type_str
,
func_name
))
func_python_name
=
build_python_name
(
module_python_name
,
func_name
)
func_node
.
python_name
=
func_python_name
node_index
[
node
]
=
func_node
node_index
[
node
]
=
func_node
self
.
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
func_node
,
output_remap
,
ignore_first
=
True
)
self
.
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
func_node
,
output_remap
,
ignore_first
=
True
)
elif
node
.
kind
()
==
'prim::Constant'
:
elif
node
.
kind
()
==
'prim::Constant'
:
...
@@ -480,7 +491,10 @@ class GraphConverter:
...
@@ -480,7 +491,10 @@ class GraphConverter:
# handle aten::XXX
# handle aten::XXX
self
.
global_seq
+=
1
self
.
global_seq
+=
1
aten_op_name
=
node
.
kind
().
replace
(
'::'
,
'__'
)
aten_op_name
=
node
.
kind
().
replace
(
'::'
,
'__'
)
aten_op_python_name
=
node
.
kind
().
replace
(
'aten::'
,
''
)
aten_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
aten_op_name
,
self
.
global_seq
),
node
.
kind
())
aten_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
aten_op_name
,
self
.
global_seq
),
node
.
kind
())
aten_python_name
=
build_python_name
(
module_python_name
,
aten_op_python_name
)
aten_node
.
python_name
=
aten_python_name
node_index
[
node
]
=
aten_node
node_index
[
node
]
=
aten_node
self
.
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
aten_node
,
output_remap
)
self
.
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
aten_node
,
output_remap
)
else
:
else
:
...
@@ -587,25 +601,29 @@ class GraphConverter:
...
@@ -587,25 +601,29 @@ class GraphConverter:
'accessor'
:
module
.
_accessor
'accessor'
:
module
.
_accessor
}
}
def
_convert_module
(
self
,
script_module
,
module
,
module_name
,
ir_model
):
def
_convert_module
(
self
,
script_module
,
module
,
module_name
,
module_python_name
,
ir_model
):
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
# also has LayerChoice or InputChoice or ValueChoice
original_type_name
=
script_module
.
original_name
original_type_name
=
script_module
.
original_name
m_attrs
=
None
m_attrs
=
None
if
original_type_name
==
OpTypeName
.
LayerChoice
:
if
original_type_name
==
OpTypeName
.
LayerChoice
:
graph
=
Graph
(
ir_model
,
-
100
,
module_name
,
_internal
=
True
)
# graph_id is not used now
graph
=
Graph
(
ir_model
,
-
100
,
module_name
,
_internal
=
True
)
# graph_id is not used now
graph
.
python_name
=
module_python_name
candidate_name_list
=
[]
candidate_name_list
=
[]
for
cand_name
in
module
.
names
:
for
cand_name
in
module
.
names
:
cand
=
module
[
cand_name
]
cand
=
module
[
cand_name
]
script_cand
=
script_module
.
_modules
[
cand_name
]
script_cand
=
script_module
.
_modules
[
cand_name
]
cand_name
=
build_cand_name
(
cand_name
,
module
.
label
)
cand_full_name
=
build_cand_name
(
cand_name
,
module
.
label
)
candidate_name_list
.
append
(
cand_name
)
cand_python_name
=
build_python_name
(
module_python_name
,
cand_name
)
subgraph
,
attrs
=
self
.
_convert_module
(
script_cand
,
cand
,
cand_name
,
ir_model
)
candidate_name_list
.
append
(
cand_full_name
)
subgraph
,
attrs
=
self
.
_convert_module
(
script_cand
,
cand
,
cand_full_name
,
cand_python_name
,
ir_model
)
if
subgraph
is
not
None
:
if
subgraph
is
not
None
:
graph
.
add_node
(
subgraph
.
name
,
Cell
(
cell_name
=
subgraph
.
name
,
parameters
=
attrs
))
cand_node
=
graph
.
add_node
(
subgraph
.
name
,
Cell
(
cell_name
=
subgraph
.
name
,
parameters
=
attrs
))
cand_node
.
python_name
=
cand_python_name
else
:
else
:
cand_type
=
'__torch__.'
+
get_importable_name
(
cand
.
__class__
)
cand_type
=
'__torch__.'
+
get_importable_name
(
cand
.
__class__
)
graph
.
add_node
(
cand_name
,
cand_type
,
attrs
)
cand_node
=
graph
.
add_node
(
cand_full_name
,
cand_type
,
attrs
)
cand_node
.
python_name
=
cand_python_name
graph
.
_register
()
graph
.
_register
()
return
graph
,
{
'mutation'
:
'layerchoice'
,
'label'
:
module
.
label
,
'candidates'
:
candidate_name_list
}
return
graph
,
{
'mutation'
:
'layerchoice'
,
'label'
:
module
.
label
,
'candidates'
:
candidate_name_list
}
elif
original_type_name
==
OpTypeName
.
InputChoice
:
elif
original_type_name
==
OpTypeName
.
InputChoice
:
...
@@ -629,10 +647,11 @@ class GraphConverter:
...
@@ -629,10 +647,11 @@ class GraphConverter:
sm_graph
=
script_module
.
graph
sm_graph
=
script_module
.
graph
self
.
global_graph_id
+=
1
self
.
global_graph_id
+=
1
ir_graph
=
Graph
(
model
=
ir_model
,
graph_id
=
self
.
global_graph_id
,
name
=
module_name
,
_internal
=
True
)
ir_graph
=
Graph
(
model
=
ir_model
,
graph_id
=
self
.
global_graph_id
,
name
=
module_name
,
_internal
=
True
)
ir_graph
.
python_name
=
module_python_name
# handle graph nodes
# handle graph nodes
self
.
handle_graph_nodes
(
script_module
,
sm_graph
,
module
,
self
.
handle_graph_nodes
(
script_module
,
sm_graph
,
module
,
module
_name
,
ir_model
,
ir_graph
)
module_name
,
module_python
_name
,
ir_model
,
ir_graph
)
self
.
refine_graph
(
ir_graph
)
self
.
refine_graph
(
ir_graph
)
ir_graph
.
_register
()
ir_graph
.
_register
()
...
@@ -671,8 +690,7 @@ class GraphConverter:
...
@@ -671,8 +690,7 @@ class GraphConverter:
dict
dict
the input arguments of this module
the input arguments of this module
"""
"""
return
self
.
_convert_module
(
script_module
,
module
,
module_name
,
None
,
ir_model
)
return
self
.
_convert_module
(
script_module
,
module
,
module_name
,
ir_model
)
class
GraphConverterWithShape
(
GraphConverter
):
class
GraphConverterWithShape
(
GraphConverter
):
...
@@ -691,7 +709,7 @@ class GraphConverterWithShape(GraphConverter):
...
@@ -691,7 +709,7 @@ class GraphConverterWithShape(GraphConverter):
def
convert_module
(
self
,
script_module
,
module
,
module_name
,
ir_model
,
dummy_input
):
def
convert_module
(
self
,
script_module
,
module
,
module_name
,
ir_model
,
dummy_input
):
module
.
eval
()
module
.
eval
()
ir_graph
,
attrs
=
self
.
_convert_module
(
script_module
,
module
,
module_name
,
ir_model
)
ir_graph
,
attrs
=
self
.
_convert_module
(
script_module
,
module
,
module_name
,
None
,
ir_model
)
self
.
remove_dummy_nodes
(
ir_model
)
self
.
remove_dummy_nodes
(
ir_model
)
self
.
_initialize_parameters
(
ir_model
)
self
.
_initialize_parameters
(
ir_model
)
self
.
_trace_module
(
module
,
module_name
,
ir_model
,
dummy_input
)
self
.
_trace_module
(
module
,
module_name
,
ir_model
,
dummy_input
)
...
...
nni/retiarii/converter/utils.py
View file @
d50b4665
...
@@ -14,6 +14,15 @@ def build_full_name(prefix, name, seq=None):
...
@@ -14,6 +14,15 @@ def build_full_name(prefix, name, seq=None):
return
'{}__{}{}'
.
format
(
prefix
,
name
,
str
(
seq
))
return
'{}__{}{}'
.
format
(
prefix
,
name
,
str
(
seq
))
def
build_python_name
(
prefix
,
name
):
if
isinstance
(
name
,
list
):
name
=
'.'
.
join
(
name
)
if
prefix
:
return
'{}.{}'
.
format
(
prefix
,
name
)
else
:
# predix could be None
return
name
def
build_cand_name
(
name
,
label
):
def
build_cand_name
(
name
,
label
):
return
f
'layerchoice_
{
label
}
_
{
name
}
'
return
f
'layerchoice_
{
label
}
_
{
name
}
'
...
...
nni/retiarii/graph.py
View file @
d50b4665
...
@@ -212,6 +212,20 @@ class Model:
...
@@ -212,6 +212,20 @@ class Model:
else
:
else
:
return
None
return
None
def
get_node_by_python_name
(
self
,
python_name
:
str
)
->
'Node'
:
"""
Traverse all the nodes to find the matched node with the given python_name.
"""
matched_nodes
=
[]
for
graph
in
self
.
graphs
.
values
():
nodes
=
graph
.
get_nodes_by_python_name
(
python_name
)
matched_nodes
.
extend
(
nodes
)
# assert len(matched_nodes) <= 1
if
matched_nodes
:
return
matched_nodes
[
0
]
else
:
return
None
def
get_cell_nodes
(
self
)
->
List
[
'Node'
]:
def
get_cell_nodes
(
self
)
->
List
[
'Node'
]:
matched_nodes
=
[]
matched_nodes
=
[]
for
graph
in
self
.
graphs
.
values
():
for
graph
in
self
.
graphs
.
values
():
...
@@ -274,6 +288,8 @@ class Graph:
...
@@ -274,6 +288,8 @@ class Graph:
All input/output/hidden nodes.
All input/output/hidden nodes.
edges
edges
...
...
python_name
The name of torch.nn.Module, should have one-to-one mapping with items in python model.
"""
"""
def
__init__
(
self
,
model
:
Model
,
graph_id
:
int
,
name
:
str
=
None
,
_internal
:
bool
=
False
):
def
__init__
(
self
,
model
:
Model
,
graph_id
:
int
,
name
:
str
=
None
,
_internal
:
bool
=
False
):
...
@@ -283,6 +299,9 @@ class Graph:
...
@@ -283,6 +299,9 @@ class Graph:
self
.
id
:
int
=
graph_id
self
.
id
:
int
=
graph_id
self
.
name
:
str
=
name
or
f
'_generated_
{
graph_id
}
'
self
.
name
:
str
=
name
or
f
'_generated_
{
graph_id
}
'
# `python_name` is `None` by default. It should be set after initialization if it is needed.
self
.
python_name
:
Optional
[
str
]
=
None
self
.
input_node
:
Node
=
Node
(
self
,
_InputPseudoUid
,
'_inputs'
,
_IOPseudoOperation
(
'_inputs'
),
_internal
=
True
)
self
.
input_node
:
Node
=
Node
(
self
,
_InputPseudoUid
,
'_inputs'
,
_IOPseudoOperation
(
'_inputs'
),
_internal
=
True
)
self
.
output_node
:
Node
=
Node
(
self
,
_OutputPseudoUid
,
'_outputs'
,
_IOPseudoOperation
(
'_outputs'
),
_internal
=
True
)
self
.
output_node
:
Node
=
Node
(
self
,
_OutputPseudoUid
,
'_outputs'
,
_IOPseudoOperation
(
'_outputs'
),
_internal
=
True
)
self
.
hidden_nodes
:
List
[
Node
]
=
[]
self
.
hidden_nodes
:
List
[
Node
]
=
[]
...
@@ -355,6 +374,13 @@ class Graph:
...
@@ -355,6 +374,13 @@ class Graph:
found
=
[
node
for
node
in
self
.
nodes
if
node
.
name
==
name
]
found
=
[
node
for
node
in
self
.
nodes
if
node
.
name
==
name
]
return
found
[
0
]
if
found
else
None
return
found
[
0
]
if
found
else
None
def
get_node_by_python_name
(
self
,
python_name
:
str
)
->
Optional
[
'Node'
]:
"""
Returns the node which has specified python_name; or returns `None` if no node has this python_name.
"""
found
=
[
node
for
node
in
self
.
nodes
if
node
.
python_name
==
python_name
]
return
found
[
0
]
if
found
else
None
def
get_nodes_by_type
(
self
,
operation_type
:
str
)
->
List
[
'Node'
]:
def
get_nodes_by_type
(
self
,
operation_type
:
str
)
->
List
[
'Node'
]:
"""
"""
Returns nodes whose operation is specified typed.
Returns nodes whose operation is specified typed.
...
@@ -374,6 +400,9 @@ class Graph:
...
@@ -374,6 +400,9 @@ class Graph:
def
get_nodes_by_name
(
self
,
name
:
str
)
->
List
[
'Node'
]:
def
get_nodes_by_name
(
self
,
name
:
str
)
->
List
[
'Node'
]:
return
[
node
for
node
in
self
.
hidden_nodes
if
node
.
name
==
name
]
return
[
node
for
node
in
self
.
hidden_nodes
if
node
.
name
==
name
]
def
get_nodes_by_python_name
(
self
,
python_name
:
str
)
->
Optional
[
'Node'
]:
return
[
node
for
node
in
self
.
nodes
if
node
.
python_name
==
python_name
]
def
topo_sort
(
self
)
->
List
[
'Node'
]:
def
topo_sort
(
self
)
->
List
[
'Node'
]:
node_to_fanin
=
{}
node_to_fanin
=
{}
curr_nodes
=
[]
curr_nodes
=
[]
...
@@ -423,9 +452,11 @@ class Graph:
...
@@ -423,9 +452,11 @@ class Graph:
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
new_graph
.
output_node
.
update_label
(
self
.
output_node
.
label
)
new_graph
.
output_node
.
update_label
(
self
.
output_node
.
label
)
new_graph
.
python_name
=
self
.
python_name
for
node
in
self
.
hidden_nodes
:
for
node
in
self
.
hidden_nodes
:
new_node
=
Node
(
new_graph
,
node
.
id
,
node
.
name
,
node
.
operation
,
_internal
=
True
)
new_node
=
Node
(
new_graph
,
node
.
id
,
node
.
name
,
node
.
operation
,
_internal
=
True
)
new_node
.
python_name
=
node
.
python_name
new_node
.
update_label
(
node
.
label
)
new_node
.
update_label
(
node
.
label
)
new_node
.
_register
()
new_node
.
_register
()
...
@@ -446,11 +477,13 @@ class Graph:
...
@@ -446,11 +477,13 @@ class Graph:
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
new_graph
.
output_node
.
update_label
(
self
.
output_node
.
label
)
new_graph
.
output_node
.
update_label
(
self
.
output_node
.
label
)
new_graph
.
python_name
=
self
.
python_name
id_to_new_node
=
{}
# old node ID -> new node object
id_to_new_node
=
{}
# old node ID -> new node object
for
old_node
in
self
.
hidden_nodes
:
for
old_node
in
self
.
hidden_nodes
:
new_node
=
Node
(
new_graph
,
uid
(),
None
,
old_node
.
operation
,
_internal
=
True
).
_register
()
new_node
=
Node
(
new_graph
,
uid
(),
None
,
old_node
.
operation
,
_internal
=
True
).
_register
()
new_node
.
python_name
=
old_node
.
python_name
new_node
.
update_label
(
old_node
.
label
)
new_node
.
update_label
(
old_node
.
label
)
id_to_new_node
[
old_node
.
id
]
=
new_node
id_to_new_node
[
old_node
.
id
]
=
new_node
...
@@ -514,6 +547,8 @@ class Node:
...
@@ -514,6 +547,8 @@ class Node:
If two models have nodes with same ID, they are semantically the same node.
If two models have nodes with same ID, they are semantically the same node.
name
name
Mnemonic name. It should have an one-to-one mapping with ID.
Mnemonic name. It should have an one-to-one mapping with ID.
python_name
The name of torch.nn.Module, should have one-to-one mapping with items in python model.
label
label
Optional. If two nodes have the same label, they are considered same by the mutator.
Optional. If two nodes have the same label, they are considered same by the mutator.
operation
operation
...
@@ -535,13 +570,15 @@ class Node:
...
@@ -535,13 +570,15 @@ class Node:
self
.
graph
:
Graph
=
graph
self
.
graph
:
Graph
=
graph
self
.
id
:
int
=
node_id
self
.
id
:
int
=
node_id
self
.
name
:
str
=
name
or
f
'_generated_
{
node_id
}
'
self
.
name
:
str
=
name
or
f
'_generated_
{
node_id
}
'
# `python_name` is `None` by default. It should be set after initialization if it is needed.
self
.
python_name
:
Optional
[
str
]
=
None
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# maybe we should copy it here or make Operation class immutable, in next release
# maybe we should copy it here or make Operation class immutable, in next release
self
.
operation
:
Operation
=
operation
self
.
operation
:
Operation
=
operation
self
.
label
:
Optional
[
str
]
=
None
self
.
label
:
Optional
[
str
]
=
None
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
'Node(id=
{
self
.
id
}
, name=
{
self
.
name
}
, label=
{
self
.
label
}
, operation=
{
self
.
operation
}
)'
return
f
'Node(id=
{
self
.
id
}
, name=
{
self
.
name
}
,
python_name=
{
self
.
python_name
}
,
label=
{
self
.
label
}
, operation=
{
self
.
operation
}
)'
@
property
@
property
def
predecessors
(
self
)
->
List
[
'Node'
]:
def
predecessors
(
self
)
->
List
[
'Node'
]:
...
@@ -626,6 +663,8 @@ class Node:
...
@@ -626,6 +663,8 @@ class Node:
ret
[
'operation'
][
'cell_name'
]
=
self
.
operation
.
cell_name
ret
[
'operation'
][
'cell_name'
]
=
self
.
operation
.
cell_name
if
self
.
label
is
not
None
:
if
self
.
label
is
not
None
:
ret
[
'label'
]
=
self
.
label
ret
[
'label'
]
=
self
.
label
if
self
.
python_name
is
not
None
:
ret
[
'python_name'
]
=
self
.
python_name
return
ret
return
ret
...
...
test/ut/retiarii/test_convert_pytorch.py
View file @
d50b4665
...
@@ -1232,5 +1232,33 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
...
@@ -1232,5 +1232,33 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
x
=
torch
.
randn
(
5
,
3
,
2
)
x
=
torch
.
randn
(
5
,
3
,
2
)
self
.
run_test
(
SizeModel
(
10
,
5
),
(
x
,
))
self
.
run_test
(
SizeModel
(
10
,
5
),
(
x
,
))
def
test_python_name
(
self
):
from
.inject_nn
import
inject_pytorch_nn
,
remove_inject_pytorch_nn
try
:
inject_pytorch_nn
()
torchvision_model_zoo
=
{
'resnet18'
:
torchvision
.
models
.
resnet18
(),
'alexnet'
:
torchvision
.
models
.
alexnet
(),
'vgg16'
:
torchvision
.
models
.
vgg16
(),
'squeezenet'
:
torchvision
.
models
.
squeezenet1_0
(),
'shufflenet_v2'
:
torchvision
.
models
.
shufflenet_v2_x1_0
(),
'mobilenet_v2'
:
torchvision
.
models
.
mobilenet_v2
(),
'resnext50_32x4d'
:
torchvision
.
models
.
resnext50_32x4d
(),
'wide_resnet50_2'
:
torchvision
.
models
.
wide_resnet50_2
(),
'mnasnet'
:
torchvision
.
models
.
mnasnet1_0
(),
}
dummy_input
=
torch
.
randn
(
1
,
3
,
224
,
224
)
for
model
in
torchvision_model_zoo
.
values
():
model_ir
=
self
.
_convert_model
(
model
,
dummy_input
)
current_name
=
[
node
.
python_name
for
node
in
model_ir
.
get_nodes
()
if
node
.
python_name
]
mentioned
=
set
()
for
k
in
model
.
state_dict
():
k
=
"."
.
join
(
k
.
split
(
"."
)[:
-
1
])
if
k
not
in
mentioned
:
assert
k
in
current_name
,
f
'
{
k
}
not in state_name'
mentioned
.
add
(
k
)
finally
:
remove_inject_pytorch_nn
()
class
TestPytorchWithShape
(
TestPytorch
,
ConvertWithShapeMixin
):
class
TestPytorchWithShape
(
TestPytorch
,
ConvertWithShapeMixin
):
pass
pass
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment