Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d50b4665
Unverified
Commit
d50b4665
authored
Nov 29, 2021
by
Jiahang Xu
Committed by
GitHub
Nov 29, 2021
Browse files
Add python name as Node attribute of graph_gen (#4243)
parent
068775f3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
110 additions
and
16 deletions
+110
-16
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+33
-15
nni/retiarii/converter/utils.py
nni/retiarii/converter/utils.py
+9
-0
nni/retiarii/graph.py
nni/retiarii/graph.py
+40
-1
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+28
-0
No files found.
nni/retiarii/converter/graph_gen.py
View file @
d50b4665
...
...
@@ -14,7 +14,8 @@ from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from
.utils
import
(
_convert_name
,
build_full_name
,
_without_shape_info
,
_extract_info_from_trace_node
,
get_full_name_by_scope_name
,
is_layerchoice_node
,
match_node
,
build_cand_name
is_layerchoice_node
,
match_node
,
build_cand_name
,
build_python_name
)
...
...
@@ -139,7 +140,7 @@ class GraphConverter:
hidden_node
.
remove
()
def
handle_graph_nodes
(
self
,
script_module
,
sm_graph
,
module
,
module_name
,
module
,
module_name
,
module_python_name
,
ir_model
,
ir_graph
,
shared_module_index
=
None
):
"""
...
...
@@ -317,10 +318,12 @@ class GraphConverter:
submodule_name
,
script_module
.
_modules
.
keys
())
submodule_full_name
=
build_full_name
(
module_name
,
submodule_name
)
submodule_python_name
=
build_python_name
(
module_python_name
,
submodule_name
)
submodule_obj
=
getattr
(
module
,
submodule_name
)
subgraph
,
sub_m_attrs
=
self
.
_convert_module
(
script_module
.
_modules
[
submodule_name
],
submodule_obj
,
submodule_full_name
,
ir_model
)
submodule_full_name
,
submodule_python_name
,
ir_model
)
else
:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
...
...
@@ -347,12 +350,14 @@ class GraphConverter:
assert
predecessor
.
hasAttribute
(
'name'
)
module_name_space
.
append
(
predecessor
.
s
(
'name'
))
submodule_full_name
=
build_full_name
(
module_name
,
list
(
reversed
(
module_name_space
)))
submodule_python_name
=
build_python_name
(
module_python_name
,
list
(
reversed
(
module_name_space
)))
submodule_obj
=
module
script_submodule
=
script_module
for
each_name
in
list
(
reversed
(
module_name_space
)):
submodule_obj
=
getattr
(
submodule_obj
,
each_name
)
script_submodule
=
script_submodule
.
_modules
[
each_name
]
subgraph
,
sub_m_attrs
=
self
.
_convert_module
(
script_submodule
,
submodule_obj
,
submodule_full_name
,
ir_model
)
subgraph
,
sub_m_attrs
=
self
.
_convert_module
(
script_submodule
,
submodule_obj
,
submodule_full_name
,
submodule_python_name
,
ir_model
)
else
:
raise
RuntimeError
(
'Unsupported module case: {}'
.
format
(
submodule
.
inputsAt
(
0
).
type
().
str
()))
...
...
@@ -362,13 +367,16 @@ class GraphConverter:
# example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
self
.
global_seq
+=
1
shared_node_name
=
build_full_name
(
submodule_full_name
,
''
,
self
.
global_seq
)
shared_node_python_name
=
build_python_name
(
submodule_python_name
,
self
.
global_seq
)
shared_type_operation
=
Operation
.
new
(
'shared'
,
{
'reference'
:
submodule_full_name
})
subcell
=
ir_graph
.
add_node
(
shared_node_name
,
shared_type_operation
)
subcell
.
python_name
=
shared_node_python_name
else
:
# this module is processed for the first time, build cell for it
if
subgraph
is
None
:
# if we do not parse this module's graph, we create Node for this module
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
submodule_type_str
,
sub_m_attrs
)
subcell
.
python_name
=
submodule_python_name
if
isinstance
(
submodule_obj
,
Placeholder
):
subcell
.
update_label
(
submodule_obj
.
label
)
elif
isinstance
(
submodule_obj
,
InputChoice
):
...
...
@@ -377,6 +385,7 @@ class GraphConverter:
# Graph already created, create Cell for it
new_cell
=
Cell
(
cell_name
=
submodule_full_name
,
parameters
=
sub_m_attrs
)
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
new_cell
)
subcell
.
python_name
=
submodule_python_name
shared_module_index
[
submodule_full_name
]
=
subcell
node_index
[
node
]
=
subcell
# connect the cell into graph
...
...
@@ -391,7 +400,7 @@ class GraphConverter:
# step #1: generate graph ir for this method
method_ir_graph
=
Graph
(
model
=
ir_model
,
graph_id
=-
100
,
name
=
'temp_graph'
,
_internal
=
True
)
self
.
handle_graph_nodes
(
script_module
,
script_method
.
graph
,
module
,
module_name
,
ir_model
,
method_ir_graph
,
shared_module_index
)
module_name
,
module_python_name
,
ir_model
,
method_ir_graph
,
shared_module_index
)
self
.
refine_graph
(
method_ir_graph
)
# step #2: merge this graph to its module graph
...
...
@@ -439,6 +448,8 @@ class GraphConverter:
self
.
global_seq
+=
1
func_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
func_name
,
self
.
global_seq
),
'{}.{}'
.
format
(
func_type_str
,
func_name
))
func_python_name
=
build_python_name
(
module_python_name
,
func_name
)
func_node
.
python_name
=
func_python_name
node_index
[
node
]
=
func_node
self
.
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
func_node
,
output_remap
,
ignore_first
=
True
)
elif
node
.
kind
()
==
'prim::Constant'
:
...
...
@@ -480,7 +491,10 @@ class GraphConverter:
# handle aten::XXX
self
.
global_seq
+=
1
aten_op_name
=
node
.
kind
().
replace
(
'::'
,
'__'
)
aten_op_python_name
=
node
.
kind
().
replace
(
'aten::'
,
''
)
aten_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
aten_op_name
,
self
.
global_seq
),
node
.
kind
())
aten_python_name
=
build_python_name
(
module_python_name
,
aten_op_python_name
)
aten_node
.
python_name
=
aten_python_name
node_index
[
node
]
=
aten_node
self
.
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
aten_node
,
output_remap
)
else
:
...
...
@@ -587,25 +601,29 @@ class GraphConverter:
'accessor'
:
module
.
_accessor
}
def
_convert_module
(
self
,
script_module
,
module
,
module_name
,
ir_model
):
def
_convert_module
(
self
,
script_module
,
module
,
module_name
,
module_python_name
,
ir_model
):
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
original_type_name
=
script_module
.
original_name
m_attrs
=
None
if
original_type_name
==
OpTypeName
.
LayerChoice
:
graph
=
Graph
(
ir_model
,
-
100
,
module_name
,
_internal
=
True
)
# graph_id is not used now
graph
.
python_name
=
module_python_name
candidate_name_list
=
[]
for
cand_name
in
module
.
names
:
cand
=
module
[
cand_name
]
script_cand
=
script_module
.
_modules
[
cand_name
]
cand_name
=
build_cand_name
(
cand_name
,
module
.
label
)
candidate_name_list
.
append
(
cand_name
)
subgraph
,
attrs
=
self
.
_convert_module
(
script_cand
,
cand
,
cand_name
,
ir_model
)
cand_full_name
=
build_cand_name
(
cand_name
,
module
.
label
)
cand_python_name
=
build_python_name
(
module_python_name
,
cand_name
)
candidate_name_list
.
append
(
cand_full_name
)
subgraph
,
attrs
=
self
.
_convert_module
(
script_cand
,
cand
,
cand_full_name
,
cand_python_name
,
ir_model
)
if
subgraph
is
not
None
:
graph
.
add_node
(
subgraph
.
name
,
Cell
(
cell_name
=
subgraph
.
name
,
parameters
=
attrs
))
cand_node
=
graph
.
add_node
(
subgraph
.
name
,
Cell
(
cell_name
=
subgraph
.
name
,
parameters
=
attrs
))
cand_node
.
python_name
=
cand_python_name
else
:
cand_type
=
'__torch__.'
+
get_importable_name
(
cand
.
__class__
)
graph
.
add_node
(
cand_name
,
cand_type
,
attrs
)
cand_node
=
graph
.
add_node
(
cand_full_name
,
cand_type
,
attrs
)
cand_node
.
python_name
=
cand_python_name
graph
.
_register
()
return
graph
,
{
'mutation'
:
'layerchoice'
,
'label'
:
module
.
label
,
'candidates'
:
candidate_name_list
}
elif
original_type_name
==
OpTypeName
.
InputChoice
:
...
...
@@ -629,10 +647,11 @@ class GraphConverter:
sm_graph
=
script_module
.
graph
self
.
global_graph_id
+=
1
ir_graph
=
Graph
(
model
=
ir_model
,
graph_id
=
self
.
global_graph_id
,
name
=
module_name
,
_internal
=
True
)
ir_graph
.
python_name
=
module_python_name
# handle graph nodes
self
.
handle_graph_nodes
(
script_module
,
sm_graph
,
module
,
module
_name
,
ir_model
,
ir_graph
)
module_name
,
module_python
_name
,
ir_model
,
ir_graph
)
self
.
refine_graph
(
ir_graph
)
ir_graph
.
_register
()
...
...
@@ -671,8 +690,7 @@ class GraphConverter:
dict
the input arguments of this module
"""
return
self
.
_convert_module
(
script_module
,
module
,
module_name
,
ir_model
)
return
self
.
_convert_module
(
script_module
,
module
,
module_name
,
None
,
ir_model
)
class
GraphConverterWithShape
(
GraphConverter
):
...
...
@@ -691,7 +709,7 @@ class GraphConverterWithShape(GraphConverter):
def
convert_module
(
self
,
script_module
,
module
,
module_name
,
ir_model
,
dummy_input
):
module
.
eval
()
ir_graph
,
attrs
=
self
.
_convert_module
(
script_module
,
module
,
module_name
,
ir_model
)
ir_graph
,
attrs
=
self
.
_convert_module
(
script_module
,
module
,
module_name
,
None
,
ir_model
)
self
.
remove_dummy_nodes
(
ir_model
)
self
.
_initialize_parameters
(
ir_model
)
self
.
_trace_module
(
module
,
module_name
,
ir_model
,
dummy_input
)
...
...
nni/retiarii/converter/utils.py
View file @
d50b4665
...
...
@@ -14,6 +14,15 @@ def build_full_name(prefix, name, seq=None):
return
'{}__{}{}'
.
format
(
prefix
,
name
,
str
(
seq
))
def
build_python_name
(
prefix
,
name
):
if
isinstance
(
name
,
list
):
name
=
'.'
.
join
(
name
)
if
prefix
:
return
'{}.{}'
.
format
(
prefix
,
name
)
else
:
# predix could be None
return
name
def
build_cand_name
(
name
,
label
):
return
f
'layerchoice_
{
label
}
_
{
name
}
'
...
...
nni/retiarii/graph.py
View file @
d50b4665
...
...
@@ -212,6 +212,20 @@ class Model:
else
:
return
None
def
get_node_by_python_name
(
self
,
python_name
:
str
)
->
'Node'
:
"""
Traverse all the nodes to find the matched node with the given python_name.
"""
matched_nodes
=
[]
for
graph
in
self
.
graphs
.
values
():
nodes
=
graph
.
get_nodes_by_python_name
(
python_name
)
matched_nodes
.
extend
(
nodes
)
# assert len(matched_nodes) <= 1
if
matched_nodes
:
return
matched_nodes
[
0
]
else
:
return
None
def
get_cell_nodes
(
self
)
->
List
[
'Node'
]:
matched_nodes
=
[]
for
graph
in
self
.
graphs
.
values
():
...
...
@@ -274,6 +288,8 @@ class Graph:
All input/output/hidden nodes.
edges
...
python_name
The name of torch.nn.Module, should have one-to-one mapping with items in python model.
"""
def
__init__
(
self
,
model
:
Model
,
graph_id
:
int
,
name
:
str
=
None
,
_internal
:
bool
=
False
):
...
...
@@ -283,6 +299,9 @@ class Graph:
self
.
id
:
int
=
graph_id
self
.
name
:
str
=
name
or
f
'_generated_
{
graph_id
}
'
# `python_name` is `None` by default. It should be set after initialization if it is needed.
self
.
python_name
:
Optional
[
str
]
=
None
self
.
input_node
:
Node
=
Node
(
self
,
_InputPseudoUid
,
'_inputs'
,
_IOPseudoOperation
(
'_inputs'
),
_internal
=
True
)
self
.
output_node
:
Node
=
Node
(
self
,
_OutputPseudoUid
,
'_outputs'
,
_IOPseudoOperation
(
'_outputs'
),
_internal
=
True
)
self
.
hidden_nodes
:
List
[
Node
]
=
[]
...
...
@@ -355,6 +374,13 @@ class Graph:
found
=
[
node
for
node
in
self
.
nodes
if
node
.
name
==
name
]
return
found
[
0
]
if
found
else
None
def
get_node_by_python_name
(
self
,
python_name
:
str
)
->
Optional
[
'Node'
]:
"""
Returns the node which has specified python_name; or returns `None` if no node has this python_name.
"""
found
=
[
node
for
node
in
self
.
nodes
if
node
.
python_name
==
python_name
]
return
found
[
0
]
if
found
else
None
def
get_nodes_by_type
(
self
,
operation_type
:
str
)
->
List
[
'Node'
]:
"""
Returns nodes whose operation is specified typed.
...
...
@@ -374,6 +400,9 @@ class Graph:
def
get_nodes_by_name
(
self
,
name
:
str
)
->
List
[
'Node'
]:
return
[
node
for
node
in
self
.
hidden_nodes
if
node
.
name
==
name
]
def
get_nodes_by_python_name
(
self
,
python_name
:
str
)
->
Optional
[
'Node'
]:
return
[
node
for
node
in
self
.
nodes
if
node
.
python_name
==
python_name
]
def
topo_sort
(
self
)
->
List
[
'Node'
]:
node_to_fanin
=
{}
curr_nodes
=
[]
...
...
@@ -423,9 +452,11 @@ class Graph:
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
new_graph
.
output_node
.
update_label
(
self
.
output_node
.
label
)
new_graph
.
python_name
=
self
.
python_name
for
node
in
self
.
hidden_nodes
:
new_node
=
Node
(
new_graph
,
node
.
id
,
node
.
name
,
node
.
operation
,
_internal
=
True
)
new_node
.
python_name
=
node
.
python_name
new_node
.
update_label
(
node
.
label
)
new_node
.
_register
()
...
...
@@ -446,11 +477,13 @@ class Graph:
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
new_graph
.
output_node
.
update_label
(
self
.
output_node
.
label
)
new_graph
.
python_name
=
self
.
python_name
id_to_new_node
=
{}
# old node ID -> new node object
for
old_node
in
self
.
hidden_nodes
:
new_node
=
Node
(
new_graph
,
uid
(),
None
,
old_node
.
operation
,
_internal
=
True
).
_register
()
new_node
.
python_name
=
old_node
.
python_name
new_node
.
update_label
(
old_node
.
label
)
id_to_new_node
[
old_node
.
id
]
=
new_node
...
...
@@ -514,6 +547,8 @@ class Node:
If two models have nodes with same ID, they are semantically the same node.
name
Mnemonic name. It should have an one-to-one mapping with ID.
python_name
The name of torch.nn.Module, should have one-to-one mapping with items in python model.
label
Optional. If two nodes have the same label, they are considered same by the mutator.
operation
...
...
@@ -535,13 +570,15 @@ class Node:
self
.
graph
:
Graph
=
graph
self
.
id
:
int
=
node_id
self
.
name
:
str
=
name
or
f
'_generated_
{
node_id
}
'
# `python_name` is `None` by default. It should be set after initialization if it is needed.
self
.
python_name
:
Optional
[
str
]
=
None
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# maybe we should copy it here or make Operation class immutable, in next release
self
.
operation
:
Operation
=
operation
self
.
label
:
Optional
[
str
]
=
None
def
__repr__
(
self
):
return
f
'Node(id=
{
self
.
id
}
, name=
{
self
.
name
}
, label=
{
self
.
label
}
, operation=
{
self
.
operation
}
)'
return
f
'Node(id=
{
self
.
id
}
, name=
{
self
.
name
}
,
python_name=
{
self
.
python_name
}
,
label=
{
self
.
label
}
, operation=
{
self
.
operation
}
)'
@
property
def
predecessors
(
self
)
->
List
[
'Node'
]:
...
...
@@ -626,6 +663,8 @@ class Node:
ret
[
'operation'
][
'cell_name'
]
=
self
.
operation
.
cell_name
if
self
.
label
is
not
None
:
ret
[
'label'
]
=
self
.
label
if
self
.
python_name
is
not
None
:
ret
[
'python_name'
]
=
self
.
python_name
return
ret
...
...
test/ut/retiarii/test_convert_pytorch.py
View file @
d50b4665
...
...
@@ -1232,5 +1232,33 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
x
=
torch
.
randn
(
5
,
3
,
2
)
self
.
run_test
(
SizeModel
(
10
,
5
),
(
x
,
))
def
test_python_name
(
self
):
from
.inject_nn
import
inject_pytorch_nn
,
remove_inject_pytorch_nn
try
:
inject_pytorch_nn
()
torchvision_model_zoo
=
{
'resnet18'
:
torchvision
.
models
.
resnet18
(),
'alexnet'
:
torchvision
.
models
.
alexnet
(),
'vgg16'
:
torchvision
.
models
.
vgg16
(),
'squeezenet'
:
torchvision
.
models
.
squeezenet1_0
(),
'shufflenet_v2'
:
torchvision
.
models
.
shufflenet_v2_x1_0
(),
'mobilenet_v2'
:
torchvision
.
models
.
mobilenet_v2
(),
'resnext50_32x4d'
:
torchvision
.
models
.
resnext50_32x4d
(),
'wide_resnet50_2'
:
torchvision
.
models
.
wide_resnet50_2
(),
'mnasnet'
:
torchvision
.
models
.
mnasnet1_0
(),
}
dummy_input
=
torch
.
randn
(
1
,
3
,
224
,
224
)
for
model
in
torchvision_model_zoo
.
values
():
model_ir
=
self
.
_convert_model
(
model
,
dummy_input
)
current_name
=
[
node
.
python_name
for
node
in
model_ir
.
get_nodes
()
if
node
.
python_name
]
mentioned
=
set
()
for
k
in
model
.
state_dict
():
k
=
"."
.
join
(
k
.
split
(
"."
)[:
-
1
])
if
k
not
in
mentioned
:
assert
k
in
current_name
,
f
'
{
k
}
not in state_name'
mentioned
.
add
(
k
)
finally
:
remove_inject_pytorch_nn
()
class
TestPytorchWithShape
(
TestPytorch
,
ConvertWithShapeMixin
):
pass
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment