Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
9444e275
Unverified
Commit
9444e275
authored
May 22, 2021
by
QuanluZhang
Committed by
GitHub
May 22, 2021
Browse files
Support nested ModuleList and fix an issue in list append (#3652)
parent
ac14b9e4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
119 additions
and
13 deletions
+119
-13
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+28
-13
test/ut/retiarii/test_convert_models.py
test/ut/retiarii/test_convert_models.py
+91
-0
No files found.
nni/retiarii/converter/graph_gen.py
View file @
9444e275
...
...
@@ -20,17 +20,17 @@ class GraphConverter:
self
.
global_graph_id
=
0
def
_add_edge_handle_source_node
(
self
,
_input
,
graph_inputs
,
ir_graph
,
output_remap
,
node_index
):
if
_input
in
graph_inputs
:
idx
=
graph_inputs
.
index
(
_input
)
src_node
=
ir_graph
.
input_node
src_node_idx
=
idx
elif
_input
in
output_remap
:
if
_input
in
output_remap
:
assert
output_remap
[
_input
].
kind
()
==
'aten::append'
predecessor_node
=
output_remap
[
_input
]
assert
predecessor_node
in
node_index
,
'predecessor node: {}'
.
format
(
predecessor_node
)
src_node_idx
=
None
src_node
=
node_index
[
predecessor_node
]
assert
isinstance
(
src_node
,
Node
)
elif
_input
in
graph_inputs
:
idx
=
graph_inputs
.
index
(
_input
)
src_node
=
ir_graph
.
input_node
src_node_idx
=
idx
else
:
predecessor_node
=
_input
.
node
()
assert
predecessor_node
in
node_index
,
'predecessor node: {}'
.
format
(
predecessor_node
)
...
...
@@ -315,16 +315,31 @@ class GraphConverter:
if
submodule
.
inputsAt
(
0
).
type
().
name
()
==
'ModuleList'
:
# handle ModuleList
predecessor
=
submodule
.
inputsAt
(
0
).
node
()
module_name_space
=
[
submodule_name
]
while
predecessor
.
inputsAt
(
0
).
debugName
()
!=
'self'
:
# this is for dealing with nested ModuleList. below is an example
# %3 : __torch__.torch.nn.modules.container.___torch_mangle_0.ModuleList = prim::GetAttr[name="ops"](%self)
# %5 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="0"](%3)
# %7 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="1"](%3)
# %9 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="2"](%3)
# %11 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="3"](%3)
# %14 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="0"](%5)
# %16 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="1"](%5)
# %state.2 : Tensor = prim::CallMethod[name="forward"](%14, %x.1) # modulelist.py:18:24
# %state.4 : Tensor = prim::CallMethod[name="forward"](%16, %state.2) # modulelist.py:18:24
assert
predecessor
.
kind
()
==
'prim::GetAttr'
module_name_space
.
append
(
predecessor
.
s
(
'name'
))
predecessor
=
predecessor
.
inputsAt
(
0
).
node
()
assert
predecessor
.
kind
()
==
'prim::GetAttr'
assert
predecessor
.
hasAttribute
(
'name'
)
assert
predecessor
.
inputsAt
(
0
).
debugName
()
==
'self'
predecessor_name
=
predecessor
.
s
(
'name'
)
# TODO: exchange submodule_name and predecessor_nam
e
submodule
_full_name
=
build_full_name
(
module_name
,
[
submodule_name
,
predecessor_name
])
predecessor_obj
=
getattr
(
module
,
predecessor_name
)
submodule_obj
=
getattr
(
predecessor_obj
,
submodule
_name
)
subgraph
,
sub_m_attrs
=
self
.
convert_
module
(
script_module
.
_modules
[
predecessor_name
].
_modules
[
submodule_name
],
submodule_obj
,
submodule_full_name
,
ir_model
)
module_name_space
.
append
(
predecessor
.
s
(
'name'
))
submodule_full_name
=
build_full_name
(
module_name
,
list
(
reversed
(
module_name_space
))
)
submodule_obj
=
modul
e
script_
submodule
=
script_module
for
each_name
in
list
(
reversed
(
module_name_space
)):
submodule_obj
=
getattr
(
submodule_obj
,
each
_name
)
script_sub
module
=
script_
sub
module
.
_modules
[
each_name
]
subgraph
,
sub_m_attrs
=
self
.
convert_module
(
script_submodule
,
submodule_obj
,
submodule_full_name
,
ir_model
)
else
:
raise
RuntimeError
(
'Unsupported module case: {}'
.
format
(
submodule
.
inputsAt
(
0
).
type
().
str
()))
...
...
test/ut/retiarii/test_convert_models.py
View file @
9444e275
import
os
import
sys
import
unittest
from
typing
import
(
Dict
)
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
class
TestModels
(
unittest
.
TestCase
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
if
check_value
:
try
:
self
.
assertEqual
(
len
(
converted_output
),
len
(
expected_output
))
for
a
,
b
in
zip
(
converted_output
,
expected_output
):
torch
.
eq
(
a
,
b
)
except
:
self
.
assertEqual
(
converted_output
,
expected_output
)
return
converted_model
def
test_nested_modulelist
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
,
num_nodes
,
num_ops_per_node
):
super
().
__init__
()
self
.
ops
=
nn
.
ModuleList
()
self
.
num_nodes
=
num_nodes
self
.
num_ops_per_node
=
num_ops_per_node
for
_
in
range
(
num_nodes
):
self
.
ops
.
append
(
nn
.
ModuleList
([
nn
.
Linear
(
16
,
16
)
for
__
in
range
(
num_ops_per_node
)]))
def
forward
(
self
,
x
):
state
=
x
for
ops
in
self
.
ops
:
for
op
in
ops
:
state
=
op
(
state
)
return
state
model
=
Net
(
4
,
2
)
x
=
torch
.
rand
((
16
,
16
),
dtype
=
torch
.
float
)
self
.
run_test
(
model
,
(
x
,
))
def
test_append_input_tensor
(
self
):
from
typing
import
List
class
Net
(
nn
.
Module
):
def
__init__
(
self
,
num_nodes
):
super
().
__init__
()
self
.
ops
=
nn
.
ModuleList
()
self
.
num_nodes
=
num_nodes
for
_
in
range
(
num_nodes
):
self
.
ops
.
append
(
nn
.
Linear
(
16
,
16
))
def
forward
(
self
,
x
:
List
[
torch
.
Tensor
]):
state
=
x
for
ops
in
self
.
ops
:
state
.
append
(
ops
(
state
[
-
1
]))
return
state
[
-
1
]
model
=
Net
(
4
)
x
=
torch
.
rand
((
1
,
16
),
dtype
=
torch
.
float
)
self
.
run_test
(
model
,
([
x
],
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment