Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
c447249c
Unverified
Commit
c447249c
authored
Mar 01, 2022
by
Yuge Zhang
Committed by
GitHub
Mar 01, 2022
Browse files
Support loading from `state_dict` of supernet (#4544)
parent
6b828681
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
280 additions
and
77 deletions
+280
-77
nni/retiarii/codegen/pytorch.py
nni/retiarii/codegen/pytorch.py
+22
-2
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+16
-2
nni/retiarii/nn/pytorch/component.py
nni/retiarii/nn/pytorch/component.py
+11
-1
nni/retiarii/nn/pytorch/nasbench101.py
nni/retiarii/nn/pytorch/nasbench101.py
+2
-0
nni/retiarii/utils.py
nni/retiarii/utils.py
+118
-0
test/ut/retiarii/debug_mnist_pytorch.py
test/ut/retiarii/debug_mnist_pytorch.py
+2
-0
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+3
-13
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+5
-13
test/ut/retiarii/test_convert_models.py
test/ut/retiarii/test_convert_models.py
+5
-14
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+5
-13
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+9
-18
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+82
-1
No files found.
nni/retiarii/codegen/pytorch.py
View file @
c447249c
...
...
@@ -6,6 +6,7 @@ import re
from
typing
import
Dict
,
List
,
Tuple
,
Any
from
nni.retiarii.operation_def.torch_op_def
import
ToDevice
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING
from
nni.common.device
import
Device
,
GPUDevice
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
...
...
@@ -97,7 +98,18 @@ def _format_variable_name(name: str, graph_name: str) -> str:
name
=
name
.
replace
(
'/'
,
'__'
)
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
return
re
.
sub
(
'\W|^(?=\d)'
,
'_'
,
name
)
name
=
re
.
sub
(
'\W|^(?=\d)'
,
'_'
,
name
)
if
name
.
startswith
(
'__'
)
and
(
len
(
name
)
>
2
and
name
[
2
]
!=
'_'
):
# name can't start with double underscore
# it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
# but it's actually very common in our generated code
name
=
name
[
1
:]
elif
name
.
startswith
(
'_'
):
# to avoid conflicts between '_' and '__'
name
=
'i'
+
name
return
name
def
generate_cuda_mapping
(
placement
:
Dict
[
Node
,
Device
])
->
Dict
[
Device
,
int
]:
...
...
@@ -125,6 +137,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
# only need to generate code for module here
import_pkgs
=
set
()
node_codes
=
[]
node_python_mappings
=
{}
cuda_remapped_id
=
None
if
placement
:
cuda_remapped_id
=
generate_cuda_mapping
(
placement
)
...
...
@@ -138,7 +151,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
pkg_name
=
node
.
operation
.
get_import_pkg
()
if
pkg_name
is
not
None
:
import_pkgs
.
add
(
pkg_name
)
node_code
=
node
.
operation
.
to_init_code
(
_format_variable_name
(
node
.
name
,
graph_name
))
py_variable_name
=
_format_variable_name
(
node
.
name
,
graph_name
)
node_code
=
node
.
operation
.
to_init_code
(
py_variable_name
)
if
node_code
is
not
None
:
if
placement
and
node
in
placement
and
len
(
node_code
)
>
0
:
if
isinstance
(
placement
[
node
],
GPUDevice
):
...
...
@@ -149,6 +164,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
else
:
node_codes
.
append
(
node_code
)
# Map to module hierarchies in original search space python code
node_python_mappings
[
py_variable_name
]
=
node
.
python_name
node_codes
.
append
(
f
'self.
{
STATE_DICT_PY_MAPPING
}
=
{
node_python_mappings
}
'
)
if
graph
.
input_node
.
operation
.
io_names
is
None
:
input_code
=
'*_inputs'
else
:
...
...
nni/retiarii/nn/pytorch/api.py
View file @
c447249c
...
...
@@ -11,6 +11,7 @@ import torch.nn as nn
from
nni.common.serializer
import
Translatable
from
nni.retiarii.serializer
import
basic_unit
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING_PARTIAL
from
.utils
import
Mutable
,
generate_new_label
,
get_fixed_value
...
...
@@ -65,9 +66,22 @@ class LayerChoice(Mutable):
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
chosen
=
get_fixed_value
(
label
)
if
isinstance
(
candidates
,
list
):
re
turn
candidates
[
int
(
chosen
)]
re
sult
=
candidates
[
int
(
chosen
)]
else
:
return
candidates
[
chosen
]
result
=
candidates
[
chosen
]
# map the named hierarchies to support weight inheritance for python engine
if
hasattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# handle cases where layer choices are nested
# already has a mapping, will merge with it
prev_mapping
=
getattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
)
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
k
:
f
'
{
chosen
}
.
{
v
}
'
for
k
,
v
in
prev_mapping
.
items
()})
else
:
# "result" needs to know where to map itself.
# Ideally, we should put a _mapping_ in the module where "result" is located,
# but it's impossible to put mapping into parent module here.
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
'__self__'
:
str
(
chosen
)})
return
result
def
__init__
(
self
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
*
,
prior
:
Optional
[
List
[
float
]]
=
None
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
...
...
nni/retiarii/nn/pytorch/component.py
View file @
c447249c
...
...
@@ -5,6 +5,8 @@ from typing import Callable, List, Union, Tuple, Optional
import
torch
import
torch.nn
as
nn
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING_PARTIAL
from
.api
import
LayerChoice
from
.cell
import
Cell
from
.nasbench101
import
NasBench101Cell
,
NasBench101Mutator
...
...
@@ -38,7 +40,15 @@ class Repeat(Mutable):
List
[
nn
.
Module
]],
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
*
,
label
:
Optional
[
str
]
=
None
):
repeat
=
get_fixed_value
(
label
)
return
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
repeat
))
result
=
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
repeat
))
if
hasattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# already has a mapping, will merge with it
prev_mapping
=
getattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
)
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
k
:
f
'blocks.
{
v
}
'
for
k
,
v
in
prev_mapping
.
items
()})
else
:
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
'__self__'
:
'blocks'
})
return
result
def
__init__
(
self
,
blocks
:
Union
[
Callable
[[
int
],
nn
.
Module
],
...
...
nni/retiarii/nn/pytorch/nasbench101.py
View file @
c447249c
...
...
@@ -304,6 +304,8 @@ class NasBench101Cell(Mutable):
[
op_candidates
[
selected
[
f
'
{
label
}
/op
{
i
}
'
]]
for
i
in
range
(
1
,
num_nodes
-
1
)],
adjacency_list
,
in_features
,
out_features
,
num_nodes
,
projection
)
# FIXME: weight inheritance on nasbench101 is not supported yet
def
__init__
(
self
,
op_candidates
:
Union
[
Dict
[
str
,
Callable
[[
int
],
nn
.
Module
]],
List
[
Callable
[[
int
],
nn
.
Module
]]],
in_features
:
int
,
out_features
:
int
,
projection
:
Callable
[[
int
,
int
],
nn
.
Module
],
max_num_nodes
:
int
=
7
,
max_num_edges
:
int
=
9
,
label
:
Optional
[
str
]
=
None
):
...
...
nni/retiarii/utils.py
View file @
c447249c
...
...
@@ -2,8 +2,10 @@
# Licensed under the MIT license.
import
inspect
import
itertools
import
warnings
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
typing
import
Any
,
List
,
Dict
from
pathlib
import
Path
...
...
@@ -154,3 +156,119 @@ class ModelNamespace:
def
get_current_context
(
key
:
str
)
->
Any
:
return
ContextStack
.
top
(
key
)
# Attribute name under which a module stores a mapping from its variable names
# to their *full* prefixes in the original (search-space) state dict.
# e.g., {'upsample': 'mynet.module.deconv2.upsample_layer'}
STATE_DICT_PY_MAPPING = '_mapping_'

# Attribute name for a *partial* mapping: variables map to `prefix`.`value`,
# i.e. values are relative to the enclosing module's own prefix.
# e.g., {'upsample': 'choice3.upsample_layer'},
# which actually means {'upsample': 'mynet.module.choice3.upsample_layer'},
# and 'upsample' is also in `mynet.module`.
STATE_DICT_PY_MAPPING_PARTIAL = '_mapping_partial_'
@contextmanager
def original_state_dict_hooks(model: Any):
    """
    Use this patch if you want to save/load state dict in the original state dict hierarchy.

    For example, when you already have a state dict for the base model / search space (which often
    happens when you have trained a supernet with one-shot strategies), the state dict isn't organized
    in the same way as when a sub-model is sampled from the search space. This patch will help
    the modules in the sub-model find the corresponding module in the base model.

    The code looks like,

    .. code-block:: python

        with original_state_dict_hooks(model):
            model.load_state_dict(state_dict_from_supernet, strict=False)  # supernet has extra keys

    Or vice-versa,

    .. code-block:: python

        with original_state_dict_hooks(model):
            supernet_style_state_dict = model.state_dict()
    """
    import torch.nn as nn
    assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'

    # the following are written for pytorch only

    # First flatten the per-module (possibly partial) mappings into one dict from
    # sub-model state-dict keys to original (supernet) state-dict keys.
    full_mapping = {}

    def full_mapping_in_module(src_prefix, tar_prefix, module):
        # Recursively walk `module`, accumulating key translations into `full_mapping`.
        # `src_prefix` is the key prefix in the sub-model; `tar_prefix` in the original model.
        if hasattr(module, STATE_DICT_PY_MAPPING):
            # only values are complete
            local_map = getattr(module, STATE_DICT_PY_MAPPING)
        elif hasattr(module, STATE_DICT_PY_MAPPING_PARTIAL):
            # keys and values are both incomplete
            local_map = getattr(module, STATE_DICT_PY_MAPPING_PARTIAL)
            # complete the values by prepending the target prefix accumulated so far
            local_map = {k: tar_prefix + v for k, v in local_map.items()}
        else:
            # no mapping
            local_map = {}

        if '__self__' in local_map:
            # special case, overwrite prefix
            tar_prefix = local_map['__self__'] + '.'

        for key, value in local_map.items():
            # NOTE: relies on the private `_modules` registry of nn.Module
            if key != '' and key not in module._modules:
                # not a sub-module, probably a parameter
                full_mapping[src_prefix + key] = value

        if src_prefix != tar_prefix:
            # To deal with leaf nodes.
            for name, value in itertools.chain(module._parameters.items(), module._buffers.items()):
                # direct children
                if value is None or name in module._non_persistent_buffers_set:
                    # it won't appear in state dict
                    continue
                if (src_prefix + name) not in full_mapping:
                    full_mapping[src_prefix + name] = tar_prefix + name

        for name, child in module.named_children():
            # sub-modules
            full_mapping_in_module(
                src_prefix + name + '.',
                local_map.get(name, tar_prefix + name) + '.',  # if mapping doesn't exist, respect the prefix
                child
            )

    full_mapping_in_module('', '', model)

    def load_state_dict_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Pre-hook for load_state_dict: rewrite supernet-style keys into
        # sub-model keys in-place before the real loading happens.
        # Several sub-model keys may map to one original key (weight sharing),
        # hence the reverse mapping collects a list.
        reverse_mapping = defaultdict(list)
        for src, tar in full_mapping.items():
            reverse_mapping[tar].append(src)

        transf_state_dict = {}
        for src, tar_keys in reverse_mapping.items():
            if src in state_dict:
                value = state_dict.pop(src)
                for tar in tar_keys:
                    transf_state_dict[tar] = value
            else:
                missing_keys.append(src)
        state_dict.update(transf_state_dict)

    def state_dict_hook(module, destination, prefix, local_metadata):
        # Post-hook for state_dict: rename sub-model keys back to their
        # original (supernet) names. Every mapped key must be present.
        result = {}
        for src, tar in full_mapping.items():
            if src in destination:
                result[tar] = destination.pop(src)
            else:
                raise KeyError(f'"{src}" not in state dict, but found in mapping.')
        destination.update(result)

    try:
        hooks = []
        # NOTE(review): these are private nn.Module APIs; their signatures may
        # change across PyTorch versions — confirm against the pinned version.
        hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
        hooks.append(model._register_state_dict_hook(state_dict_hook))
        yield
    finally:
        # always detach the hooks so the patch doesn't leak outside the context
        for hook in hooks:
            hook.remove()
test/ut/retiarii/debug_mnist_pytorch.py
View file @
c447249c
...
...
@@ -16,6 +16,7 @@ class _model(nn.Module):
self
.
fc1
=
torch
.
nn
.
Linear
(
out_features
=
256
,
in_features
=
1024
)
self
.
fc2
=
torch
.
nn
.
Linear
(
out_features
=
10
,
in_features
=
256
)
self
.
softmax
=
torch
.
nn
.
Softmax
()
self
.
_mapping_
=
{
'stem'
:
None
,
'flatten'
:
None
,
'fc1'
:
None
,
'fc2'
:
None
,
'softmax'
:
None
}
def
forward
(
self
,
image
):
stem
=
self
.
stem
(
image
)
...
...
@@ -34,6 +35,7 @@ class stem(nn.Module):
self
.
pool1
=
torch
.
nn
.
MaxPool2d
(
kernel_size
=
2
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
=
64
,
in_channels
=
32
,
kernel_size
=
5
)
self
.
pool2
=
torch
.
nn
.
MaxPool2d
(
kernel_size
=
2
)
self
.
_mapping_
=
{
'conv1'
:
None
,
'pool1'
:
None
,
'conv2'
:
None
,
'pool2'
:
None
}
def
forward
(
self
,
*
_inputs
):
conv1
=
self
.
conv1
(
_inputs
[
0
])
...
...
test/ut/retiarii/test_convert.py
View file @
c447249c
...
...
@@ -14,6 +14,7 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
...
...
@@ -50,16 +51,6 @@ class Linear(nn.Module):
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@staticmethod
def _match_state_dict(current_values, expected_format):
    """Greedily pair tensors from ``current_values`` with the keys of
    ``expected_format`` by matching shapes.

    For each expected key, the first not-yet-used value with an identical
    ``.shape`` is taken (and removed from ``current_values``).
    Returns a dict of matched key -> value.
    """
    matched = {}
    for key, expected in expected_format.items():
        for pos, candidate in enumerate(current_values):
            if candidate.shape == expected.shape:
                matched[key] = candidate
                # consume the candidate so it cannot be matched twice
                del current_values[pos]
                break
    return matched
def
checkExportImport
(
self
,
model
,
input
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
...
@@ -68,9 +59,8 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
dict
(
model
.
state_dict
()))
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_basic.py
View file @
c447249c
...
...
@@ -12,20 +12,11 @@ from nni.retiarii import basic_unit
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
# following pytorch v1.7.1
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
...
@@ -35,9 +26,10 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
model
.
state_dict
())
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_models.py
View file @
c447249c
...
...
@@ -9,23 +9,13 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestModels
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
...
@@ -35,9 +25,10 @@ class TestModels(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
model
.
state_dict
())
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_operators.py
View file @
c447249c
...
...
@@ -16,6 +16,7 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
...
...
@@ -23,16 +24,6 @@ from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class
TestOperators
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
...
@@ -42,9 +33,10 @@ class TestOperators(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
model
.
state_dict
())
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_pytorch.py
View file @
c447249c
...
...
@@ -14,28 +14,17 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestPytorch
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
def
run_test
(
self
,
model
,
input
,
check_value
=
True
,
strict_load
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
from
.inject_nn
import
remove_inject_pytorch_nn
remove_inject_pytorch_nn
()
...
...
@@ -43,9 +32,10 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
model
.
state_dict
(),
strict
=
strict_load
)
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
@@ -76,7 +66,8 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
model
=
LargeModel
()
x
=
torch
.
tensor
([
2
],
dtype
=
torch
.
long
)
self
.
run_test
(
model
,
(
x
,
))
# emb and lin1 is actually not used so they won't appear in generated model
self
.
run_test
(
model
,
(
x
,
),
strict_load
=
False
)
@
unittest
.
skip
(
'skip for now, as it needs inject_nn'
)
def
test_mobilenet_v2_with_external_data
(
self
):
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
c447249c
...
...
@@ -17,7 +17,7 @@ from nni.retiarii.graph import Model
from
nni.retiarii.nn.pytorch.api
import
ValueChoice
from
nni.retiarii.nn.pytorch.mutator
import
process_evaluator_mutations
,
process_inline_mutation
,
extract_mutation_from_pt_module
from
nni.retiarii.serializer
import
model_wrapper
from
nni.retiarii.utils
import
ContextStack
from
nni.retiarii.utils
import
ContextStack
,
original_state_dict_hooks
class
EnumerateSampler
(
Sampler
):
...
...
@@ -123,6 +123,29 @@ class GraphIR(unittest.TestCase):
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model_new
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
i
,
3
,
3
]))
def test_layer_choice_weight_inheritance(self):
    """Weights loaded from the supernet must land on the chosen candidate of a LayerChoice."""
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # 10 conv candidates, distinguishable by their output channels (1..10)
            self.module = nn.LayerChoice([nn.Conv2d(3, i, kernel_size=1) for i in range(1, 11)])

        def forward(self, x):
            return self.module(x)

    orig_model = Net()
    model, mutators = self._get_model_with_mutators(orig_model)
    mutator = mutators[0].bind_sampler(EnumerateSampler())

    for i in range(1, 11):
        # sample the i-th candidate and convert the sub-model to pytorch
        model_new = mutator.apply(model)
        model_new = self._get_converted_pytorch_model(model_new)
        # strict=False: supernet state dict has keys for the unchosen candidates
        with original_state_dict_hooks(model_new):
            model_new.load_state_dict(orig_model.state_dict(), strict=False)
        inp = torch.randn(1, 3, 3, 3)
        # candidate modules inside LayerChoice are addressed by stringified index
        a = getattr(orig_model.module, str(i - 1))(inp)
        b = model_new(inp)
        # outputs must coincide if the weights were inherited correctly
        self.assertLess((a - b).abs().max().item(), 1E-4)
def
test_nested_layer_choice
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
...
...
@@ -150,6 +173,40 @@ class GraphIR(unittest.TestCase):
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
mutators
[
1
].
apply
(
mutators
[
0
].
apply
(
model
)))(
input
).
size
(),
torch
.
Size
([
1
,
5
,
5
,
5
]))
def test_nested_layer_choice_weight_inheritance(self):
    """Weight inheritance must also work when a LayerChoice is nested inside another."""
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # outer choice: [inner choice of 3 convs, plain conv]
            self.module = nn.LayerChoice([
                nn.LayerChoice([nn.Conv2d(3, 3, kernel_size=1),
                                nn.Conv2d(3, 4, kernel_size=1),
                                nn.Conv2d(3, 5, kernel_size=1)]),
                nn.Conv2d(3, 1, kernel_size=1)
            ])

        def forward(self, x):
            return self.module(x)

    orig_model = Net()
    model, mutators = self._get_model_with_mutators(orig_model)
    mutators[0].bind_sampler(EnumerateSampler())
    mutators[1].bind_sampler(EnumerateSampler())

    input = torch.randn(1, 3, 5, 5)
    for i in range(3):
        # apply both mutators (outer + inner choice) and convert
        model_new = self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))
        with original_state_dict_hooks(model_new):
            model_new.load_state_dict(orig_model.state_dict(), strict=False)
        # expected reference module for each enumeration step
        # (NOTE(review): the i -> choice correspondence follows EnumerateSampler's order)
        if i == 0:
            a = getattr(getattr(orig_model.module, '0'), '0')(input)
        elif i == 1:
            a = getattr(orig_model.module, '1')(input)
        elif i == 2:
            a = getattr(getattr(orig_model.module, '0'), '2')(input)
        b = model_new(input)
        self.assertLess((a - b).abs().max().item(), 1E-4)
def
test_input_choice
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
...
...
@@ -578,6 +635,30 @@ class GraphIR(unittest.TestCase):
self
.
assertIn
(
1.
,
result
)
def test_repeat_weight_inheritance(self):
    """Weights of a Repeat's blocks must be inherited for every sampled depth."""
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # variable-depth stack of identical 3x3-channel 1x1 convs, depth in [2, 5]
            self.module = nn.Repeat(lambda index: nn.Conv2d(3, 3, 1), (2, 5))

        def forward(self, x):
            return self.module(x)

    orig_model = Net()
    model, mutators = self._get_model_with_mutators(orig_model)
    mutator = mutators[0].bind_sampler(EnumerateSampler())

    inp = torch.randn(1, 3, 5, 5)
    for i in range(4):
        # EnumerateSampler walks the depths in order: 2, 3, 4, 5
        model_new = self._get_converted_pytorch_model(mutator.apply(model))
        with original_state_dict_hooks(model_new):
            model_new.load_state_dict(orig_model.state_dict(), strict=False)
        # reference: run the first (i + 2) original blocks sequentially
        a = nn.Sequential(*orig_model.module.blocks[:i + 2])(inp)
        b = model_new(inp)
        self.assertLess((a - b).abs().max().item(), 1E-4)
def
test_cell
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment