Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
c447249c
"vscode:/vscode.git/clone" did not exist on "bc26a2faea287cec6ceca03d6b8d4bbcc2e9a635"
Unverified
Commit
c447249c
authored
Mar 01, 2022
by
Yuge Zhang
Committed by
GitHub
Mar 01, 2022
Browse files
Support loading from `state_dict` of supernet (#4544)
parent
6b828681
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
280 additions
and
77 deletions
+280
-77
nni/retiarii/codegen/pytorch.py
nni/retiarii/codegen/pytorch.py
+22
-2
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+16
-2
nni/retiarii/nn/pytorch/component.py
nni/retiarii/nn/pytorch/component.py
+11
-1
nni/retiarii/nn/pytorch/nasbench101.py
nni/retiarii/nn/pytorch/nasbench101.py
+2
-0
nni/retiarii/utils.py
nni/retiarii/utils.py
+118
-0
test/ut/retiarii/debug_mnist_pytorch.py
test/ut/retiarii/debug_mnist_pytorch.py
+2
-0
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+3
-13
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+5
-13
test/ut/retiarii/test_convert_models.py
test/ut/retiarii/test_convert_models.py
+5
-14
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+5
-13
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+9
-18
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+82
-1
No files found.
nni/retiarii/codegen/pytorch.py
View file @
c447249c
...
...
@@ -6,6 +6,7 @@ import re
from
typing
import
Dict
,
List
,
Tuple
,
Any
from
nni.retiarii.operation_def.torch_op_def
import
ToDevice
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING
from
nni.common.device
import
Device
,
GPUDevice
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
...
...
@@ -97,7 +98,18 @@ def _format_variable_name(name: str, graph_name: str) -> str:
name
=
name
.
replace
(
'/'
,
'__'
)
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
return
re
.
sub
(
'\W|^(?=\d)'
,
'_'
,
name
)
name
=
re
.
sub
(
'\W|^(?=\d)'
,
'_'
,
name
)
if
name
.
startswith
(
'__'
)
and
(
len
(
name
)
>
2
and
name
[
2
]
!=
'_'
):
# name can't start with double underscore
# it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
# but it's actually very common in our generated code
name
=
name
[
1
:]
elif
name
.
startswith
(
'_'
):
# to avoid conflicts between '_' and '__'
name
=
'i'
+
name
return
name
def
generate_cuda_mapping
(
placement
:
Dict
[
Node
,
Device
])
->
Dict
[
Device
,
int
]:
...
...
@@ -125,6 +137,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
# only need to generate code for module here
import_pkgs
=
set
()
node_codes
=
[]
node_python_mappings
=
{}
cuda_remapped_id
=
None
if
placement
:
cuda_remapped_id
=
generate_cuda_mapping
(
placement
)
...
...
@@ -138,7 +151,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
pkg_name
=
node
.
operation
.
get_import_pkg
()
if
pkg_name
is
not
None
:
import_pkgs
.
add
(
pkg_name
)
node_code
=
node
.
operation
.
to_init_code
(
_format_variable_name
(
node
.
name
,
graph_name
))
py_variable_name
=
_format_variable_name
(
node
.
name
,
graph_name
)
node_code
=
node
.
operation
.
to_init_code
(
py_variable_name
)
if
node_code
is
not
None
:
if
placement
and
node
in
placement
and
len
(
node_code
)
>
0
:
if
isinstance
(
placement
[
node
],
GPUDevice
):
...
...
@@ -149,6 +164,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
else
:
node_codes
.
append
(
node_code
)
# Map to module hierarchies in original search space python code
node_python_mappings
[
py_variable_name
]
=
node
.
python_name
node_codes
.
append
(
f
'self.
{
STATE_DICT_PY_MAPPING
}
=
{
node_python_mappings
}
'
)
if
graph
.
input_node
.
operation
.
io_names
is
None
:
input_code
=
'*_inputs'
else
:
...
...
nni/retiarii/nn/pytorch/api.py
View file @
c447249c
...
...
@@ -11,6 +11,7 @@ import torch.nn as nn
from
nni.common.serializer
import
Translatable
from
nni.retiarii.serializer
import
basic_unit
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING_PARTIAL
from
.utils
import
Mutable
,
generate_new_label
,
get_fixed_value
...
...
@@ -65,9 +66,22 @@ class LayerChoice(Mutable):
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
chosen
=
get_fixed_value
(
label
)
if
isinstance
(
candidates
,
list
):
re
turn
candidates
[
int
(
chosen
)]
re
sult
=
candidates
[
int
(
chosen
)]
else
:
return
candidates
[
chosen
]
result
=
candidates
[
chosen
]
# map the named hierarchies to support weight inheritance for python engine
if
hasattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# handle cases where layer choices are nested
# already has a mapping, will merge with it
prev_mapping
=
getattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
)
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
k
:
f
'
{
chosen
}
.
{
v
}
'
for
k
,
v
in
prev_mapping
.
items
()})
else
:
# "result" needs to know where to map itself.
# Ideally, we should put a _mapping_ in the module where "result" is located,
# but it's impossible to put mapping into parent module here.
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
'__self__'
:
str
(
chosen
)})
return
result
def
__init__
(
self
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
*
,
prior
:
Optional
[
List
[
float
]]
=
None
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
...
...
nni/retiarii/nn/pytorch/component.py
View file @
c447249c
...
...
@@ -5,6 +5,8 @@ from typing import Callable, List, Union, Tuple, Optional
import
torch
import
torch.nn
as
nn
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING_PARTIAL
from
.api
import
LayerChoice
from
.cell
import
Cell
from
.nasbench101
import
NasBench101Cell
,
NasBench101Mutator
...
...
@@ -38,7 +40,15 @@ class Repeat(Mutable):
List
[
nn
.
Module
]],
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
*
,
label
:
Optional
[
str
]
=
None
):
repeat
=
get_fixed_value
(
label
)
return
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
repeat
))
result
=
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
repeat
))
if
hasattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# already has a mapping, will merge with it
prev_mapping
=
getattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
)
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
k
:
f
'blocks.
{
v
}
'
for
k
,
v
in
prev_mapping
.
items
()})
else
:
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
'__self__'
:
'blocks'
})
return
result
def
__init__
(
self
,
blocks
:
Union
[
Callable
[[
int
],
nn
.
Module
],
...
...
nni/retiarii/nn/pytorch/nasbench101.py
View file @
c447249c
...
...
@@ -304,6 +304,8 @@ class NasBench101Cell(Mutable):
[
op_candidates
[
selected
[
f
'
{
label
}
/op
{
i
}
'
]]
for
i
in
range
(
1
,
num_nodes
-
1
)],
adjacency_list
,
in_features
,
out_features
,
num_nodes
,
projection
)
# FIXME: weight inheritance on nasbench101 is not supported yet
def
__init__
(
self
,
op_candidates
:
Union
[
Dict
[
str
,
Callable
[[
int
],
nn
.
Module
]],
List
[
Callable
[[
int
],
nn
.
Module
]]],
in_features
:
int
,
out_features
:
int
,
projection
:
Callable
[[
int
,
int
],
nn
.
Module
],
max_num_nodes
:
int
=
7
,
max_num_edges
:
int
=
9
,
label
:
Optional
[
str
]
=
None
):
...
...
nni/retiarii/utils.py
View file @
c447249c
...
...
@@ -2,8 +2,10 @@
# Licensed under the MIT license.
import
inspect
import
itertools
import
warnings
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
typing
import
Any
,
List
,
Dict
from
pathlib
import
Path
...
...
@@ -154,3 +156,119 @@ class ModelNamespace:
def
get_current_context
(
key
:
str
)
->
Any
:
return
ContextStack
.
top
(
key
)
# map variables to prefix in the state dict
# e.g., {'upsample': 'mynet.module.deconv2.upsample_layer'}
STATE_DICT_PY_MAPPING
=
'_mapping_'
# map variables to `prefix`.`value` in the state dict
# e.g., {'upsample': 'choice3.upsample_layer'},
# which actually means {'upsample': 'mynet.module.choice3.upsample_layer'},
# and 'upsample' is also in `mynet.module`.
STATE_DICT_PY_MAPPING_PARTIAL
=
'_mapping_partial_'
@
contextmanager
def
original_state_dict_hooks
(
model
:
Any
):
"""
Use this patch if you want to save/load state dict in the original state dict hierarchy.
For example, when you already have a state dict for the base model / search space (which often
happens when you have trained a supernet with one-shot strategies), the state dict isn't organized
in the same way as when a sub-model is sampled from the search space. This patch will help
the modules in the sub-model find the corresponding module in the base model.
The code looks like,
.. code-block:: python
with original_state_dict_hooks(model):
model.load_state_dict(state_dict_from_supernet, strict=False) # supernet has extra keys
Or vice-versa,
.. code-block:: python
with original_state_dict_hooks(model):
supernet_style_state_dict = model.state_dict()
"""
import
torch.nn
as
nn
assert
isinstance
(
model
,
nn
.
Module
),
'PyTorch is the only supported framework for now.'
# the following are written for pytorch only
# first get the full mapping
full_mapping
=
{}
def
full_mapping_in_module
(
src_prefix
,
tar_prefix
,
module
):
if
hasattr
(
module
,
STATE_DICT_PY_MAPPING
):
# only values are complete
local_map
=
getattr
(
module
,
STATE_DICT_PY_MAPPING
)
elif
hasattr
(
module
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# keys and values are both incomplete
local_map
=
getattr
(
module
,
STATE_DICT_PY_MAPPING_PARTIAL
)
local_map
=
{
k
:
tar_prefix
+
v
for
k
,
v
in
local_map
.
items
()}
else
:
# no mapping
local_map
=
{}
if
'__self__'
in
local_map
:
# special case, overwrite prefix
tar_prefix
=
local_map
[
'__self__'
]
+
'.'
for
key
,
value
in
local_map
.
items
():
if
key
!=
''
and
key
not
in
module
.
_modules
:
# not a sub-module, probably a parameter
full_mapping
[
src_prefix
+
key
]
=
value
if
src_prefix
!=
tar_prefix
:
# To deal with leaf nodes.
for
name
,
value
in
itertools
.
chain
(
module
.
_parameters
.
items
(),
module
.
_buffers
.
items
()):
# direct children
if
value
is
None
or
name
in
module
.
_non_persistent_buffers_set
:
# it won't appear in state dict
continue
if
(
src_prefix
+
name
)
not
in
full_mapping
:
full_mapping
[
src_prefix
+
name
]
=
tar_prefix
+
name
for
name
,
child
in
module
.
named_children
():
# sub-modules
full_mapping_in_module
(
src_prefix
+
name
+
'.'
,
local_map
.
get
(
name
,
tar_prefix
+
name
)
+
'.'
,
# if mapping doesn't exist, respect the prefix
child
)
full_mapping_in_module
(
''
,
''
,
model
)
def
load_state_dict_hook
(
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
reverse_mapping
=
defaultdict
(
list
)
for
src
,
tar
in
full_mapping
.
items
():
reverse_mapping
[
tar
].
append
(
src
)
transf_state_dict
=
{}
for
src
,
tar_keys
in
reverse_mapping
.
items
():
if
src
in
state_dict
:
value
=
state_dict
.
pop
(
src
)
for
tar
in
tar_keys
:
transf_state_dict
[
tar
]
=
value
else
:
missing_keys
.
append
(
src
)
state_dict
.
update
(
transf_state_dict
)
def
state_dict_hook
(
module
,
destination
,
prefix
,
local_metadata
):
result
=
{}
for
src
,
tar
in
full_mapping
.
items
():
if
src
in
destination
:
result
[
tar
]
=
destination
.
pop
(
src
)
else
:
raise
KeyError
(
f
'"
{
src
}
" not in state dict, but found in mapping.'
)
destination
.
update
(
result
)
try
:
hooks
=
[]
hooks
.
append
(
model
.
_register_load_state_dict_pre_hook
(
load_state_dict_hook
))
hooks
.
append
(
model
.
_register_state_dict_hook
(
state_dict_hook
))
yield
finally
:
for
hook
in
hooks
:
hook
.
remove
()
test/ut/retiarii/debug_mnist_pytorch.py
View file @
c447249c
...
...
@@ -16,6 +16,7 @@ class _model(nn.Module):
self
.
fc1
=
torch
.
nn
.
Linear
(
out_features
=
256
,
in_features
=
1024
)
self
.
fc2
=
torch
.
nn
.
Linear
(
out_features
=
10
,
in_features
=
256
)
self
.
softmax
=
torch
.
nn
.
Softmax
()
self
.
_mapping_
=
{
'stem'
:
None
,
'flatten'
:
None
,
'fc1'
:
None
,
'fc2'
:
None
,
'softmax'
:
None
}
def
forward
(
self
,
image
):
stem
=
self
.
stem
(
image
)
...
...
@@ -34,6 +35,7 @@ class stem(nn.Module):
self
.
pool1
=
torch
.
nn
.
MaxPool2d
(
kernel_size
=
2
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
=
64
,
in_channels
=
32
,
kernel_size
=
5
)
self
.
pool2
=
torch
.
nn
.
MaxPool2d
(
kernel_size
=
2
)
self
.
_mapping_
=
{
'conv1'
:
None
,
'pool1'
:
None
,
'conv2'
:
None
,
'pool2'
:
None
}
def
forward
(
self
,
*
_inputs
):
conv1
=
self
.
conv1
(
_inputs
[
0
])
...
...
test/ut/retiarii/test_convert.py
View file @
c447249c
...
...
@@ -14,6 +14,7 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
...
...
@@ -50,16 +51,6 @@ class Linear(nn.Module):
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
...
@@ -68,9 +59,8 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
dict
(
model
.
state_dict
()))
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_basic.py
View file @
c447249c
...
...
@@ -12,20 +12,11 @@ from nni.retiarii import basic_unit
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
# following pytorch v1.7.1
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
...
@@ -35,9 +26,10 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
model
.
state_dict
())
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_models.py
View file @
c447249c
...
...
@@ -9,23 +9,13 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestModels
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
...
@@ -35,9 +25,10 @@ class TestModels(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
model
.
state_dict
())
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_operators.py
View file @
c447249c
...
...
@@ -16,6 +16,7 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
...
...
@@ -23,16 +24,6 @@ from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class
TestOperators
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
...
@@ -42,9 +33,10 @@ class TestOperators(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
model
.
state_dict
())
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_pytorch.py
View file @
c447249c
...
...
@@ -14,28 +14,17 @@ import torch.nn.functional as F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestPytorch
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
def
run_test
(
self
,
model
,
input
,
check_value
=
True
,
strict_load
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
from
.inject_nn
import
remove_inject_pytorch_nn
remove_inject_pytorch_nn
()
...
...
@@ -43,9 +32,10 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
model
.
state_dict
(),
strict
=
strict_load
)
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
@@ -76,7 +66,8 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
model
=
LargeModel
()
x
=
torch
.
tensor
([
2
],
dtype
=
torch
.
long
)
self
.
run_test
(
model
,
(
x
,
))
# emb and lin1 is actually not used so they won't appear in generated model
self
.
run_test
(
model
,
(
x
,
),
strict_load
=
False
)
@
unittest
.
skip
(
'skip for now, as it needs inject_nn'
)
def
test_mobilenet_v2_with_external_data
(
self
):
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
c447249c
...
...
@@ -17,7 +17,7 @@ from nni.retiarii.graph import Model
from
nni.retiarii.nn.pytorch.api
import
ValueChoice
from
nni.retiarii.nn.pytorch.mutator
import
process_evaluator_mutations
,
process_inline_mutation
,
extract_mutation_from_pt_module
from
nni.retiarii.serializer
import
model_wrapper
from
nni.retiarii.utils
import
ContextStack
from
nni.retiarii.utils
import
ContextStack
,
original_state_dict_hooks
class
EnumerateSampler
(
Sampler
):
...
...
@@ -123,6 +123,29 @@ class GraphIR(unittest.TestCase):
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model_new
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
i
,
3
,
3
]))
def
test_layer_choice_weight_inheritance
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
module
=
nn
.
LayerChoice
([
nn
.
Conv2d
(
3
,
i
,
kernel_size
=
1
)
for
i
in
range
(
1
,
11
)])
def
forward
(
self
,
x
):
return
self
.
module
(
x
)
orig_model
=
Net
()
model
,
mutators
=
self
.
_get_model_with_mutators
(
orig_model
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
for
i
in
range
(
1
,
11
):
model_new
=
mutator
.
apply
(
model
)
model_new
=
self
.
_get_converted_pytorch_model
(
model_new
)
with
original_state_dict_hooks
(
model_new
):
model_new
.
load_state_dict
(
orig_model
.
state_dict
(),
strict
=
False
)
inp
=
torch
.
randn
(
1
,
3
,
3
,
3
)
a
=
getattr
(
orig_model
.
module
,
str
(
i
-
1
))(
inp
)
b
=
model_new
(
inp
)
self
.
assertLess
((
a
-
b
).
abs
().
max
().
item
(),
1E-4
)
def
test_nested_layer_choice
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
...
...
@@ -150,6 +173,40 @@ class GraphIR(unittest.TestCase):
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
mutators
[
1
].
apply
(
mutators
[
0
].
apply
(
model
)))(
input
).
size
(),
torch
.
Size
([
1
,
5
,
5
,
5
]))
def
test_nested_layer_choice_weight_inheritance
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
module
=
nn
.
LayerChoice
([
nn
.
LayerChoice
([
nn
.
Conv2d
(
3
,
3
,
kernel_size
=
1
),
nn
.
Conv2d
(
3
,
4
,
kernel_size
=
1
),
nn
.
Conv2d
(
3
,
5
,
kernel_size
=
1
)]),
nn
.
Conv2d
(
3
,
1
,
kernel_size
=
1
)
])
def
forward
(
self
,
x
):
return
self
.
module
(
x
)
orig_model
=
Net
()
model
,
mutators
=
self
.
_get_model_with_mutators
(
orig_model
)
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
mutators
[
1
].
bind_sampler
(
EnumerateSampler
())
input
=
torch
.
randn
(
1
,
3
,
5
,
5
)
for
i
in
range
(
3
):
model_new
=
self
.
_get_converted_pytorch_model
(
mutators
[
1
].
apply
(
mutators
[
0
].
apply
(
model
)))
with
original_state_dict_hooks
(
model_new
):
model_new
.
load_state_dict
(
orig_model
.
state_dict
(),
strict
=
False
)
if
i
==
0
:
a
=
getattr
(
getattr
(
orig_model
.
module
,
'0'
),
'0'
)(
input
)
elif
i
==
1
:
a
=
getattr
(
orig_model
.
module
,
'1'
)(
input
)
elif
i
==
2
:
a
=
getattr
(
getattr
(
orig_model
.
module
,
'0'
),
'2'
)(
input
)
b
=
model_new
(
input
)
self
.
assertLess
((
a
-
b
).
abs
().
max
().
item
(),
1E-4
)
def
test_input_choice
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
...
...
@@ -578,6 +635,30 @@ class GraphIR(unittest.TestCase):
self
.
assertIn
(
1.
,
result
)
def
test_repeat_weight_inheritance
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
module
=
nn
.
Repeat
(
lambda
index
:
nn
.
Conv2d
(
3
,
3
,
1
),
(
2
,
5
))
def
forward
(
self
,
x
):
return
self
.
module
(
x
)
orig_model
=
Net
()
model
,
mutators
=
self
.
_get_model_with_mutators
(
orig_model
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
inp
=
torch
.
randn
(
1
,
3
,
5
,
5
)
for
i
in
range
(
4
):
model_new
=
self
.
_get_converted_pytorch_model
(
mutator
.
apply
(
model
))
with
original_state_dict_hooks
(
model_new
):
model_new
.
load_state_dict
(
orig_model
.
state_dict
(),
strict
=
False
)
a
=
nn
.
Sequential
(
*
orig_model
.
module
.
blocks
[:
i
+
2
])(
inp
)
b
=
model_new
(
inp
)
self
.
assertLess
((
a
-
b
).
abs
().
max
().
item
(),
1E-4
)
def
test_cell
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment