Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
6f3ed2bf
Unverified
Commit
6f3ed2bf
authored
Mar 22, 2022
by
J-shang
Committed by
GitHub
Mar 22, 2022
Browse files
Merge pull request #4670 from liuzhe-lz/doc-merge
parents
553e91f4
13dc0f8f
Changes
49
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
357 additions
and
125 deletions
+357
-125
nni/retiarii/codegen/pytorch.py
nni/retiarii/codegen/pytorch.py
+22
-2
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
+1
-0
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+16
-2
nni/retiarii/nn/pytorch/component.py
nni/retiarii/nn/pytorch/component.py
+11
-1
nni/retiarii/nn/pytorch/nasbench101.py
nni/retiarii/nn/pytorch/nasbench101.py
+2
-0
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+42
-17
nni/retiarii/strategy/_rl_impl.py
nni/retiarii/strategy/_rl_impl.py
+11
-0
nni/retiarii/utils.py
nni/retiarii/utils.py
+118
-0
nni/tools/nnictl/legacy_launcher.py
nni/tools/nnictl/legacy_launcher.py
+6
-10
pipelines/fast-test.yml
pipelines/fast-test.yml
+1
-1
test/ut/compression/v1/test_model_speedup.py
test/ut/compression/v1/test_model_speedup.py
+40
-0
test/ut/compression/v2/test_iterative_pruner_torch.py
test/ut/compression/v2/test_iterative_pruner_torch.py
+27
-9
test/ut/compression/v2/test_pruner_torch.py
test/ut/compression/v2/test_pruner_torch.py
+30
-12
test/ut/retiarii/debug_mnist_pytorch.py
test/ut/retiarii/debug_mnist_pytorch.py
+2
-0
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+1
-0
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+3
-13
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+5
-13
test/ut/retiarii/test_convert_models.py
test/ut/retiarii/test_convert_models.py
+5
-14
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+5
-13
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+9
-18
No files found.
nni/retiarii/codegen/pytorch.py
View file @
6f3ed2bf
...
@@ -6,6 +6,7 @@ import re
...
@@ -6,6 +6,7 @@ import re
from
typing
import
Dict
,
List
,
Tuple
,
Any
from
typing
import
Dict
,
List
,
Tuple
,
Any
from
nni.retiarii.operation_def.torch_op_def
import
ToDevice
from
nni.retiarii.operation_def.torch_op_def
import
ToDevice
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING
from
nni.common.device
import
Device
,
GPUDevice
from
nni.common.device
import
Device
,
GPUDevice
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
...
@@ -97,7 +98,18 @@ def _format_variable_name(name: str, graph_name: str) -> str:
...
@@ -97,7 +98,18 @@ def _format_variable_name(name: str, graph_name: str) -> str:
name
=
name
.
replace
(
'/'
,
'__'
)
name
=
name
.
replace
(
'/'
,
'__'
)
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
return
re
.
sub
(
'\W|^(?=\d)'
,
'_'
,
name
)
name
=
re
.
sub
(
'\W|^(?=\d)'
,
'_'
,
name
)
if
name
.
startswith
(
'__'
)
and
(
len
(
name
)
>
2
and
name
[
2
]
!=
'_'
):
# name can't start with double underscore
# it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
# but it's actually very common in our generated code
name
=
name
[
1
:]
elif
name
.
startswith
(
'_'
):
# to avoid conflicts between '_' and '__'
name
=
'i'
+
name
return
name
def
generate_cuda_mapping
(
placement
:
Dict
[
Node
,
Device
])
->
Dict
[
Device
,
int
]:
def
generate_cuda_mapping
(
placement
:
Dict
[
Node
,
Device
])
->
Dict
[
Device
,
int
]:
...
@@ -125,6 +137,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
...
@@ -125,6 +137,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
# only need to generate code for module here
# only need to generate code for module here
import_pkgs
=
set
()
import_pkgs
=
set
()
node_codes
=
[]
node_codes
=
[]
node_python_mappings
=
{}
cuda_remapped_id
=
None
cuda_remapped_id
=
None
if
placement
:
if
placement
:
cuda_remapped_id
=
generate_cuda_mapping
(
placement
)
cuda_remapped_id
=
generate_cuda_mapping
(
placement
)
...
@@ -138,7 +151,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
...
@@ -138,7 +151,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
pkg_name
=
node
.
operation
.
get_import_pkg
()
pkg_name
=
node
.
operation
.
get_import_pkg
()
if
pkg_name
is
not
None
:
if
pkg_name
is
not
None
:
import_pkgs
.
add
(
pkg_name
)
import_pkgs
.
add
(
pkg_name
)
node_code
=
node
.
operation
.
to_init_code
(
_format_variable_name
(
node
.
name
,
graph_name
))
py_variable_name
=
_format_variable_name
(
node
.
name
,
graph_name
)
node_code
=
node
.
operation
.
to_init_code
(
py_variable_name
)
if
node_code
is
not
None
:
if
node_code
is
not
None
:
if
placement
and
node
in
placement
and
len
(
node_code
)
>
0
:
if
placement
and
node
in
placement
and
len
(
node_code
)
>
0
:
if
isinstance
(
placement
[
node
],
GPUDevice
):
if
isinstance
(
placement
[
node
],
GPUDevice
):
...
@@ -149,6 +164,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
...
@@ -149,6 +164,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
else
:
else
:
node_codes
.
append
(
node_code
)
node_codes
.
append
(
node_code
)
# Map to module hierarchies in original search space python code
node_python_mappings
[
py_variable_name
]
=
node
.
python_name
node_codes
.
append
(
f
'self.
{
STATE_DICT_PY_MAPPING
}
=
{
node_python_mappings
}
'
)
if
graph
.
input_node
.
operation
.
io_names
is
None
:
if
graph
.
input_node
.
operation
.
io_names
is
None
:
input_code
=
'*_inputs'
input_code
=
'*_inputs'
else
:
else
:
...
...
nni/retiarii/evaluator/pytorch/cgo/evaluator.py
View file @
6f3ed2bf
...
@@ -101,6 +101,7 @@ class _MultiModelSupervisedLearningModule(LightningModule):
...
@@ -101,6 +101,7 @@ class _MultiModelSupervisedLearningModule(LightningModule):
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
return
{
name
:
self
.
trainer
.
callback_metrics
[
'val_'
+
name
].
item
()
for
name
in
self
.
metrics
}
@
nni
.
trace
class
MultiModelSupervisedLearningModule
(
_MultiModelSupervisedLearningModule
):
class
MultiModelSupervisedLearningModule
(
_MultiModelSupervisedLearningModule
):
"""
"""
Lightning Module of SupervisedLearning for Cross-Graph Optimization.
Lightning Module of SupervisedLearning for Cross-Graph Optimization.
...
...
nni/retiarii/nn/pytorch/api.py
View file @
6f3ed2bf
...
@@ -11,6 +11,7 @@ import torch.nn as nn
...
@@ -11,6 +11,7 @@ import torch.nn as nn
from
nni.common.serializer
import
Translatable
from
nni.common.serializer
import
Translatable
from
nni.retiarii.serializer
import
basic_unit
from
nni.retiarii.serializer
import
basic_unit
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING_PARTIAL
from
.utils
import
Mutable
,
generate_new_label
,
get_fixed_value
from
.utils
import
Mutable
,
generate_new_label
,
get_fixed_value
...
@@ -82,9 +83,22 @@ class LayerChoice(Mutable):
...
@@ -82,9 +83,22 @@ class LayerChoice(Mutable):
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
chosen
=
get_fixed_value
(
label
)
chosen
=
get_fixed_value
(
label
)
if
isinstance
(
candidates
,
list
):
if
isinstance
(
candidates
,
list
):
re
turn
candidates
[
int
(
chosen
)]
re
sult
=
candidates
[
int
(
chosen
)]
else
:
else
:
return
candidates
[
chosen
]
result
=
candidates
[
chosen
]
# map the named hierarchies to support weight inheritance for python engine
if
hasattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# handle cases where layer choices are nested
# already has a mapping, will merge with it
prev_mapping
=
getattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
)
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
k
:
f
'
{
chosen
}
.
{
v
}
'
for
k
,
v
in
prev_mapping
.
items
()})
else
:
# "result" needs to know where to map itself.
# Ideally, we should put a _mapping_ in the module where "result" is located,
# but it's impossible to put mapping into parent module here.
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
'__self__'
:
str
(
chosen
)})
return
result
def
__init__
(
self
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
*
,
def
__init__
(
self
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
*
,
prior
:
Optional
[
List
[
float
]]
=
None
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
prior
:
Optional
[
List
[
float
]]
=
None
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
...
...
nni/retiarii/nn/pytorch/component.py
View file @
6f3ed2bf
...
@@ -5,6 +5,8 @@ from typing import Callable, List, Union, Tuple, Optional
...
@@ -5,6 +5,8 @@ from typing import Callable, List, Union, Tuple, Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING_PARTIAL
from
.api
import
LayerChoice
from
.api
import
LayerChoice
from
.cell
import
Cell
from
.cell
import
Cell
from
.nasbench101
import
NasBench101Cell
,
NasBench101Mutator
from
.nasbench101
import
NasBench101Cell
,
NasBench101Mutator
...
@@ -59,7 +61,15 @@ class Repeat(Mutable):
...
@@ -59,7 +61,15 @@ class Repeat(Mutable):
List
[
nn
.
Module
]],
List
[
nn
.
Module
]],
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
*
,
label
:
Optional
[
str
]
=
None
):
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
*
,
label
:
Optional
[
str
]
=
None
):
repeat
=
get_fixed_value
(
label
)
repeat
=
get_fixed_value
(
label
)
return
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
repeat
))
result
=
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
repeat
))
if
hasattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# already has a mapping, will merge with it
prev_mapping
=
getattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
)
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
k
:
f
'blocks.
{
v
}
'
for
k
,
v
in
prev_mapping
.
items
()})
else
:
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
'__self__'
:
'blocks'
})
return
result
def
__init__
(
self
,
def
__init__
(
self
,
blocks
:
Union
[
Callable
[[
int
],
nn
.
Module
],
blocks
:
Union
[
Callable
[[
int
],
nn
.
Module
],
...
...
nni/retiarii/nn/pytorch/nasbench101.py
View file @
6f3ed2bf
...
@@ -301,6 +301,8 @@ class NasBench101Cell(Mutable):
...
@@ -301,6 +301,8 @@ class NasBench101Cell(Mutable):
[
op_candidates
[
selected
[
f
'
{
label
}
/op
{
i
}
'
]]
for
i
in
range
(
1
,
num_nodes
-
1
)],
[
op_candidates
[
selected
[
f
'
{
label
}
/op
{
i
}
'
]]
for
i
in
range
(
1
,
num_nodes
-
1
)],
adjacency_list
,
in_features
,
out_features
,
num_nodes
,
projection
)
adjacency_list
,
in_features
,
out_features
,
num_nodes
,
projection
)
# FIXME: weight inheritance on nasbench101 is not supported yet
def
__init__
(
self
,
op_candidates
:
Union
[
Dict
[
str
,
Callable
[[
int
],
nn
.
Module
]],
List
[
Callable
[[
int
],
nn
.
Module
]]],
def
__init__
(
self
,
op_candidates
:
Union
[
Dict
[
str
,
Callable
[[
int
],
nn
.
Module
]],
List
[
Callable
[[
int
],
nn
.
Module
]]],
in_features
:
int
,
out_features
:
int
,
projection
:
Callable
[[
int
,
int
],
nn
.
Module
],
in_features
:
int
,
out_features
:
int
,
projection
:
Callable
[[
int
,
int
],
nn
.
Module
],
max_num_nodes
:
int
=
7
,
max_num_edges
:
int
=
9
,
label
:
Optional
[
str
]
=
None
):
max_num_nodes
:
int
=
7
,
max_num_edges
:
int
=
9
,
label
:
Optional
[
str
]
=
None
):
...
...
nni/retiarii/serializer.py
View file @
6f3ed2bf
...
@@ -6,7 +6,7 @@ import os
...
@@ -6,7 +6,7 @@ import os
import
warnings
import
warnings
from
typing
import
Any
,
TypeVar
,
Union
from
typing
import
Any
,
TypeVar
,
Union
from
nni.common.serializer
import
Traceable
,
is_traceable
,
trace
,
_copy_class_wrapper_attributes
from
nni.common.serializer
import
Traceable
,
is_traceable
,
is_wrapped_with_trace
,
trace
,
_copy_class_wrapper_attributes
from
.utils
import
ModelNamespace
from
.utils
import
ModelNamespace
__all__
=
[
'get_init_parameters_or_fail'
,
'serialize'
,
'serialize_cls'
,
'basic_unit'
,
'model_wrapper'
,
__all__
=
[
'get_init_parameters_or_fail'
,
'serialize'
,
'serialize_cls'
,
'basic_unit'
,
'model_wrapper'
,
...
@@ -71,7 +71,8 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
...
@@ -71,7 +71,8 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
if
nni_trace_flag
.
lower
()
==
'disable'
:
if
nni_trace_flag
.
lower
()
==
'disable'
:
return
cls
return
cls
_check_wrapped
(
cls
)
if
_check_wrapped
(
cls
,
'basic_unit'
):
return
cls
import
torch.nn
as
nn
import
torch.nn
as
nn
assert
issubclass
(
cls
,
nn
.
Module
),
'When using @basic_unit, the class must be a subclass of nn.Module.'
assert
issubclass
(
cls
,
nn
.
Module
),
'When using @basic_unit, the class must be a subclass of nn.Module.'
...
@@ -79,15 +80,7 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
...
@@ -79,15 +80,7 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
cls
=
trace
(
cls
)
cls
=
trace
(
cls
)
cls
.
_nni_basic_unit
=
basic_unit_tag
cls
.
_nni_basic_unit
=
basic_unit_tag
# HACK: for torch script
_torchscript_patch
(
cls
)
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import
torch
cls
.
_get_nni_attr
=
torch
.
jit
.
ignore
(
cls
.
_get_nni_attr
)
cls
.
trace_symbol
=
torch
.
jit
.
unused
(
cls
.
trace_symbol
)
cls
.
trace_args
=
torch
.
jit
.
unused
(
cls
.
trace_args
)
cls
.
trace_kwargs
=
torch
.
jit
.
unused
(
cls
.
trace_kwargs
)
return
cls
return
cls
...
@@ -116,12 +109,14 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
...
@@ -116,12 +109,14 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
if
nni_trace_flag
.
lower
()
==
'disable'
:
if
nni_trace_flag
.
lower
()
==
'disable'
:
return
cls
return
cls
_check_wrapped
(
cls
)
if
_check_wrapped
(
cls
,
'model_wrapper'
):
return
cls
import
torch.nn
as
nn
import
torch.nn
as
nn
assert
issubclass
(
cls
,
nn
.
Module
)
assert
issubclass
(
cls
,
nn
.
Module
)
wrapper
=
trace
(
cls
)
# subclass can still use trace info
wrapper
=
trace
(
cls
,
inheritable
=
True
)
class
reset_wrapper
(
wrapper
):
class
reset_wrapper
(
wrapper
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
...
@@ -129,8 +124,12 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
...
@@ -129,8 +124,12 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
_copy_class_wrapper_attributes
(
wrapper
,
reset_wrapper
)
_copy_class_wrapper_attributes
(
wrapper
,
reset_wrapper
)
reset_wrapper
.
__wrapped__
=
wrapper
.
__wrapped__
reset_wrapper
.
__wrapped__
=
getattr
(
wrapper
,
'
__wrapped__
'
,
wrapper
)
reset_wrapper
.
_nni_model_wrapper
=
True
reset_wrapper
.
_nni_model_wrapper
=
True
reset_wrapper
.
_traced
=
True
_torchscript_patch
(
cls
)
return
reset_wrapper
return
reset_wrapper
...
@@ -146,6 +145,32 @@ def is_model_wrapped(cls_or_instance) -> bool:
...
@@ -146,6 +145,32 @@ def is_model_wrapped(cls_or_instance) -> bool:
return
getattr
(
cls_or_instance
,
'_nni_model_wrapper'
,
False
)
return
getattr
(
cls_or_instance
,
'_nni_model_wrapper'
,
False
)
def
_check_wrapped
(
cls
:
T
)
->
bool
:
def
_check_wrapped
(
cls
:
T
,
rewrap
:
str
)
->
bool
:
if
getattr
(
cls
,
'_traced'
,
False
)
or
getattr
(
cls
,
'_nni_model_wrapper'
,
False
):
wrapped
=
None
raise
TypeError
(
f
'
{
cls
}
is already wrapped with trace wrapper (basic_unit / model_wrapper / trace). Cannot wrap again.'
)
if
is_model_wrapped
(
cls
):
wrapped
=
'model_wrapper'
elif
is_basic_unit
(
cls
):
wrapped
=
'basic_unit'
elif
is_wrapped_with_trace
(
cls
):
wrapped
=
'nni.trace'
if
wrapped
:
if
wrapped
!=
rewrap
:
raise
TypeError
(
f
'
{
cls
}
is already wrapped with
{
wrapped
}
. Cannot rewrap with
{
rewrap
}
.'
)
return
True
return
False
def
_torchscript_patch
(
cls
)
->
None
:
# HACK: for torch script
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import
torch
if
hasattr
(
cls
,
'_get_nni_attr'
):
# could not exist on non-linux
cls
.
_get_nni_attr
=
torch
.
jit
.
ignore
(
cls
.
_get_nni_attr
)
if
hasattr
(
cls
,
'trace_symbol'
):
# these must all exist or all non-exist
cls
.
trace_symbol
=
torch
.
jit
.
unused
(
cls
.
trace_symbol
)
cls
.
trace_args
=
torch
.
jit
.
unused
(
cls
.
trace_args
)
cls
.
trace_kwargs
=
torch
.
jit
.
unused
(
cls
.
trace_kwargs
)
cls
.
trace_copy
=
torch
.
jit
.
ignore
(
cls
.
trace_copy
)
nni/retiarii/strategy/_rl_impl.py
View file @
6f3ed2bf
...
@@ -43,6 +43,17 @@ class MultiThreadEnvWorker(EnvWorker):
...
@@ -43,6 +43,17 @@ class MultiThreadEnvWorker(EnvWorker):
def
reset
(
self
):
def
reset
(
self
):
return
self
.
env
.
reset
()
return
self
.
env
.
reset
()
def
send
(
self
,
action
):
# for tianshou >= 0.4.6
if
action
is
None
:
self
.
result
=
self
.
pool
.
apply_async
(
self
.
env
.
reset
)
else
:
self
.
send_action
(
action
)
def
recv
(
self
):
# for tianshou >= 0.4.6
return
self
.
result
.
get
()
@
staticmethod
@
staticmethod
def
wait
(
*
args
,
**
kwargs
):
def
wait
(
*
args
,
**
kwargs
):
raise
NotImplementedError
(
'Async collect is not supported yet.'
)
raise
NotImplementedError
(
'Async collect is not supported yet.'
)
...
...
nni/retiarii/utils.py
View file @
6f3ed2bf
...
@@ -2,8 +2,10 @@
...
@@ -2,8 +2,10 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
inspect
import
inspect
import
itertools
import
warnings
import
warnings
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
typing
import
Any
,
List
,
Dict
from
typing
import
Any
,
List
,
Dict
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -150,3 +152,119 @@ class ModelNamespace:
...
@@ -150,3 +152,119 @@ class ModelNamespace:
def
get_current_context
(
key
:
str
)
->
Any
:
def
get_current_context
(
key
:
str
)
->
Any
:
return
ContextStack
.
top
(
key
)
return
ContextStack
.
top
(
key
)
# map variables to prefix in the state dict
# e.g., {'upsample': 'mynet.module.deconv2.upsample_layer'}
STATE_DICT_PY_MAPPING
=
'_mapping_'
# map variables to `prefix`.`value` in the state dict
# e.g., {'upsample': 'choice3.upsample_layer'},
# which actually means {'upsample': 'mynet.module.choice3.upsample_layer'},
# and 'upsample' is also in `mynet.module`.
STATE_DICT_PY_MAPPING_PARTIAL
=
'_mapping_partial_'
@
contextmanager
def
original_state_dict_hooks
(
model
:
Any
):
"""
Use this patch if you want to save/load state dict in the original state dict hierarchy.
For example, when you already have a state dict for the base model / search space (which often
happens when you have trained a supernet with one-shot strategies), the state dict isn't organized
in the same way as when a sub-model is sampled from the search space. This patch will help
the modules in the sub-model find the corresponding module in the base model.
The code looks like,
.. code-block:: python
with original_state_dict_hooks(model):
model.load_state_dict(state_dict_from_supernet, strict=False) # supernet has extra keys
Or vice-versa,
.. code-block:: python
with original_state_dict_hooks(model):
supernet_style_state_dict = model.state_dict()
"""
import
torch.nn
as
nn
assert
isinstance
(
model
,
nn
.
Module
),
'PyTorch is the only supported framework for now.'
# the following are written for pytorch only
# first get the full mapping
full_mapping
=
{}
def
full_mapping_in_module
(
src_prefix
,
tar_prefix
,
module
):
if
hasattr
(
module
,
STATE_DICT_PY_MAPPING
):
# only values are complete
local_map
=
getattr
(
module
,
STATE_DICT_PY_MAPPING
)
elif
hasattr
(
module
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# keys and values are both incomplete
local_map
=
getattr
(
module
,
STATE_DICT_PY_MAPPING_PARTIAL
)
local_map
=
{
k
:
tar_prefix
+
v
for
k
,
v
in
local_map
.
items
()}
else
:
# no mapping
local_map
=
{}
if
'__self__'
in
local_map
:
# special case, overwrite prefix
tar_prefix
=
local_map
[
'__self__'
]
+
'.'
for
key
,
value
in
local_map
.
items
():
if
key
!=
''
and
key
not
in
module
.
_modules
:
# not a sub-module, probably a parameter
full_mapping
[
src_prefix
+
key
]
=
value
if
src_prefix
!=
tar_prefix
:
# To deal with leaf nodes.
for
name
,
value
in
itertools
.
chain
(
module
.
_parameters
.
items
(),
module
.
_buffers
.
items
()):
# direct children
if
value
is
None
or
name
in
module
.
_non_persistent_buffers_set
:
# it won't appear in state dict
continue
if
(
src_prefix
+
name
)
not
in
full_mapping
:
full_mapping
[
src_prefix
+
name
]
=
tar_prefix
+
name
for
name
,
child
in
module
.
named_children
():
# sub-modules
full_mapping_in_module
(
src_prefix
+
name
+
'.'
,
local_map
.
get
(
name
,
tar_prefix
+
name
)
+
'.'
,
# if mapping doesn't exist, respect the prefix
child
)
full_mapping_in_module
(
''
,
''
,
model
)
def
load_state_dict_hook
(
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
reverse_mapping
=
defaultdict
(
list
)
for
src
,
tar
in
full_mapping
.
items
():
reverse_mapping
[
tar
].
append
(
src
)
transf_state_dict
=
{}
for
src
,
tar_keys
in
reverse_mapping
.
items
():
if
src
in
state_dict
:
value
=
state_dict
.
pop
(
src
)
for
tar
in
tar_keys
:
transf_state_dict
[
tar
]
=
value
else
:
missing_keys
.
append
(
src
)
state_dict
.
update
(
transf_state_dict
)
def
state_dict_hook
(
module
,
destination
,
prefix
,
local_metadata
):
result
=
{}
for
src
,
tar
in
full_mapping
.
items
():
if
src
in
destination
:
result
[
tar
]
=
destination
.
pop
(
src
)
else
:
raise
KeyError
(
f
'"
{
src
}
" not in state dict, but found in mapping.'
)
destination
.
update
(
result
)
try
:
hooks
=
[]
hooks
.
append
(
model
.
_register_load_state_dict_pre_hook
(
load_state_dict_hook
))
hooks
.
append
(
model
.
_register_state_dict_hook
(
state_dict_hook
))
yield
finally
:
for
hook
in
hooks
:
hook
.
remove
()
nni/tools/nnictl/legacy_launcher.py
View file @
6f3ed2bf
...
@@ -70,21 +70,17 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
...
@@ -70,21 +70,17 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
node_command
=
os
.
path
.
join
(
entry_dir
,
'node'
)
node_command
=
os
.
path
.
join
(
entry_dir
,
'node'
)
cmds
=
[
node_command
,
'--max-old-space-size=4096'
,
entry_file
,
'--port'
,
str
(
port
),
'--mode'
,
platform
,
\
cmds
=
[
node_command
,
'--max-old-space-size=4096'
,
entry_file
,
'--port'
,
str
(
port
),
'--mode'
,
platform
,
\
'--experiment_id'
,
experiment_id
]
'--experiment_id'
,
experiment_id
]
if
mode
==
'view'
:
cmds
+=
[
'--action'
,
mode
]
cmds
+=
[
'--start_mode'
,
'resume'
]
cmds
+=
[
'--readonly'
,
'true'
]
else
:
cmds
+=
[
'--start_mode'
,
mode
]
if
log_dir
is
not
None
:
if
log_dir
is
not
None
:
cmds
+=
[
'--
log_dir
'
,
log_dir
]
cmds
+=
[
'--
experiments-directory
'
,
log_dir
]
if
log_level
is
not
None
:
if
log_level
is
not
None
:
cmds
+=
[
'--log
_
level'
,
log_level
]
cmds
+=
[
'--log
-
level'
,
log_level
]
if
foreground
:
if
foreground
:
cmds
+=
[
'--foreground'
,
'true'
]
cmds
+=
[
'--foreground'
,
'true'
]
if
url_prefix
:
if
url_prefix
:
_validate_prefix_path
(
url_prefix
)
_validate_prefix_path
(
url_prefix
)
set_prefix_url
(
url_prefix
)
set_prefix_url
(
url_prefix
)
cmds
+=
[
'--url
_
prefix'
,
url_prefix
]
cmds
+=
[
'--url
-
prefix'
,
url_prefix
.
strip
(
'/'
)
]
stdout_full_path
,
stderr_full_path
=
get_log_path
(
experiment_id
)
stdout_full_path
,
stderr_full_path
=
get_log_path
(
experiment_id
)
with
open
(
stdout_full_path
,
'a+'
)
as
stdout_file
,
open
(
stderr_full_path
,
'a+'
)
as
stderr_file
:
with
open
(
stdout_full_path
,
'a+'
)
as
stdout_file
,
open
(
stderr_full_path
,
'a+'
)
as
stderr_file
:
...
@@ -520,9 +516,9 @@ def create_experiment(args):
...
@@ -520,9 +516,9 @@ def create_experiment(args):
try
:
try
:
if
schema
==
1
:
if
schema
==
1
:
launch_experiment
(
args
,
config_v1
,
'
new
'
,
experiment_id
,
1
)
launch_experiment
(
args
,
config_v1
,
'
create
'
,
experiment_id
,
1
)
else
:
else
:
launch_experiment
(
args
,
config_v2
,
'
new
'
,
experiment_id
,
2
)
launch_experiment
(
args
,
config_v2
,
'
create
'
,
experiment_id
,
2
)
except
Exception
as
exception
:
except
Exception
as
exception
:
restServerPid
=
Experiments
().
get_all_experiments
().
get
(
experiment_id
,
{}).
get
(
'pid'
)
restServerPid
=
Experiments
().
get_all_experiments
().
get
(
experiment_id
,
{}).
get
(
'pid'
)
if
restServerPid
:
if
restServerPid
:
...
...
pipelines/fast-test.yml
View file @
6f3ed2bf
...
@@ -177,7 +177,7 @@ stages:
...
@@ -177,7 +177,7 @@ stages:
-
job
:
windows
-
job
:
windows
pool
:
pool
:
vmImage
:
windows-latest
vmImage
:
windows-latest
timeoutInMinutes
:
7
0
timeoutInMinutes
:
7
5
steps
:
steps
:
-
template
:
templates/install-dependencies.yml
-
template
:
templates/install-dependencies.yml
...
...
test/ut/compression/v1/test_model_speedup.py
View file @
6f3ed2bf
...
@@ -512,6 +512,46 @@ class SpeedupTestCase(TestCase):
...
@@ -512,6 +512,46 @@ class SpeedupTestCase(TestCase):
print
(
"Fine-grained speeduped model"
)
print
(
"Fine-grained speeduped model"
)
print
(
model
)
print
(
model
)
def
test_multiplication_speedup
(
self
):
"""
Model from issue 4540.
"""
class
Net
(
torch
.
nn
.
Module
):
def
__init__
(
self
,):
super
(
Net
,
self
).
__init__
()
self
.
avgpool
=
torch
.
nn
.
AdaptiveAvgPool2d
(
1
)
self
.
input
=
torch
.
nn
.
Conv2d
(
3
,
8
,
3
)
self
.
bn
=
torch
.
nn
.
BatchNorm2d
(
8
)
self
.
fc1
=
torch
.
nn
.
Conv2d
(
8
,
16
,
1
)
self
.
fc2
=
torch
.
nn
.
Conv2d
(
16
,
8
,
1
)
self
.
activation
=
torch
.
nn
.
ReLU
()
self
.
scale_activation
=
torch
.
nn
.
Hardsigmoid
()
self
.
out
=
torch
.
nn
.
Conv2d
(
8
,
12
,
1
)
def
forward
(
self
,
input
):
input
=
self
.
activation
(
self
.
bn
(
self
.
input
(
input
)))
scale
=
self
.
avgpool
(
input
)
out1
=
self
.
activation
(
self
.
fc1
(
scale
))
out1
=
self
.
scale_activation
(
self
.
fc2
(
out1
))
return
self
.
out
(
out1
*
input
)
model
=
Net
().
to
(
device
)
model
.
eval
()
im
=
torch
.
ones
(
1
,
3
,
512
,
512
).
to
(
device
)
model
(
im
)
cfg_list
=
[]
for
name
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
torch
.
nn
.
Conv2d
):
cfg_list
.
append
({
'op_types'
:[
'Conv2d'
],
'sparsity'
:
0.3
,
'op_names'
:[
name
]})
pruner
=
L1FilterPruner
(
model
,
cfg_list
)
pruner
.
compress
()
pruner
.
export_model
(
MODEL_FILE
,
MASK_FILE
)
pruner
.
_unwrap_model
()
ms
=
ModelSpeedup
(
model
,
im
,
MASK_FILE
)
ms
.
speedup_model
()
def
tearDown
(
self
):
def
tearDown
(
self
):
if
os
.
path
.
exists
(
MODEL_FILE
):
if
os
.
path
.
exists
(
MODEL_FILE
):
os
.
remove
(
MODEL_FILE
)
os
.
remove
(
MODEL_FILE
)
...
...
test/ut/compression/v2/test_iterative_pruner_torch.py
View file @
6f3ed2bf
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
random
import
random
import
unittest
import
unittest
import
numpy
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -105,6 +106,17 @@ class IterativePrunerTestCase(unittest.TestCase):
...
@@ -105,6 +106,17 @@ class IterativePrunerTestCase(unittest.TestCase):
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
assert
0.78
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.82
assert
0.78
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.82
def
test_amc_pruner
(
self
):
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'total_sparsity'
:
0.5
,
'max_sparsity_per_layer'
:
0.8
}]
dummy_input
=
torch
.
rand
(
10
,
1
,
28
,
28
)
ddpg_params
=
{
'hidden1'
:
300
,
'hidden2'
:
300
,
'lr_c'
:
1e-3
,
'lr_a'
:
1e-4
,
'warmup'
:
5
,
'discount'
:
1.
,
'bsize'
:
64
,
'rmsize'
:
100
,
'window_length'
:
1
,
'tau'
:
0.01
,
'init_delta'
:
0.5
,
'delta_decay'
:
0.99
,
'max_episode_length'
:
1e9
,
'epsilon'
:
50000
}
pruner
=
AMCPruner
(
10
,
model
,
config_list
,
dummy_input
,
evaluator
,
finetuner
=
finetuner
,
ddpg_params
=
ddpg_params
,
target
=
'flops'
,
log_dir
=
'../../../logs'
)
pruner
.
compress
()
class
FixSeedPrunerTestCase
(
unittest
.
TestCase
):
def
test_auto_compress_pruner
(
self
):
def
test_auto_compress_pruner
(
self
):
model
=
TorchModel
()
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'total_sparsity'
:
0.8
}]
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'total_sparsity'
:
0.8
}]
...
@@ -126,15 +138,21 @@ class IterativePrunerTestCase(unittest.TestCase):
...
@@ -126,15 +138,21 @@ class IterativePrunerTestCase(unittest.TestCase):
print
(
sparsity_list
)
print
(
sparsity_list
)
assert
0.78
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.82
assert
0.78
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.82
def
test_amc_pruner
(
self
):
def
setUp
(
self
)
->
None
:
model
=
TorchModel
()
# fix seed in order to solve the random failure of ut
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'total_sparsity'
:
0.5
,
'max_sparsity_per_layer'
:
0.8
}]
random
.
seed
(
1024
)
dummy_input
=
torch
.
rand
(
10
,
1
,
28
,
28
)
numpy
.
random
.
seed
(
1024
)
ddpg_params
=
{
'hidden1'
:
300
,
'hidden2'
:
300
,
'lr_c'
:
1e-3
,
'lr_a'
:
1e-4
,
'warmup'
:
5
,
'discount'
:
1.
,
torch
.
manual_seed
(
1024
)
'bsize'
:
64
,
'rmsize'
:
100
,
'window_length'
:
1
,
'tau'
:
0.01
,
'init_delta'
:
0.5
,
'delta_decay'
:
0.99
,
'max_episode_length'
:
1e9
,
'epsilon'
:
50000
}
def
tearDown
(
self
)
->
None
:
pruner
=
AMCPruner
(
10
,
model
,
config_list
,
dummy_input
,
evaluator
,
finetuner
=
finetuner
,
ddpg_params
=
ddpg_params
,
target
=
'flops'
,
log_dir
=
'../../../logs'
)
# reset seed
pruner
.
compress
()
import
time
now
=
int
(
time
.
time
()
*
100
)
random
.
seed
(
now
)
seed
=
random
.
randint
(
0
,
2
**
32
-
1
)
random
.
seed
(
seed
)
numpy
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
test/ut/compression/v2/test_pruner_torch.py
View file @
6f3ed2bf
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
import
random
import
unittest
import
unittest
import
numpy
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -122,18 +124,6 @@ class PrunerTestCase(unittest.TestCase):
...
@@ -122,18 +124,6 @@ class PrunerTestCase(unittest.TestCase):
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
assert
0.78
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.82
assert
0.78
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.82
def
test_activation_apoz_rank_pruner
(
self
):
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
pruner
=
ActivationAPoZRankPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
traced_optimizer
=
get_optimizer
(
model
),
criterion
=
criterion
,
training_batches
=
5
,
activation
=
'relu'
,
mode
=
'dependency_aware'
,
dummy_input
=
torch
.
rand
(
10
,
1
,
28
,
28
))
pruned_model
,
masks
=
pruner
.
compress
()
pruner
.
_unwrap_model
()
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
assert
0.78
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.82
def
test_activation_mean_rank_pruner
(
self
):
def
test_activation_mean_rank_pruner
(
self
):
model
=
TorchModel
()
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
...
@@ -177,6 +167,34 @@ class PrunerTestCase(unittest.TestCase):
...
@@ -177,6 +167,34 @@ class PrunerTestCase(unittest.TestCase):
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
assert
0.78
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.82
assert
0.78
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.82
class
FixSeedPrunerTestCase
(
unittest
.
TestCase
):
def
test_activation_apoz_rank_pruner
(
self
):
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
pruner
=
ActivationAPoZRankPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
traced_optimizer
=
get_optimizer
(
model
),
criterion
=
criterion
,
training_batches
=
5
,
activation
=
'relu'
,
mode
=
'dependency_aware'
,
dummy_input
=
torch
.
rand
(
10
,
1
,
28
,
28
))
pruned_model
,
masks
=
pruner
.
compress
()
pruner
.
_unwrap_model
()
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
assert
0.78
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.82
def
setUp
(
self
)
->
None
:
# fix seed in order to solve the random failure of ut
random
.
seed
(
1024
)
numpy
.
random
.
seed
(
1024
)
torch
.
manual_seed
(
1024
)
def
tearDown
(
self
)
->
None
:
# reset seed
import
time
now
=
int
(
time
.
time
()
*
100
)
random
.
seed
(
now
)
seed
=
random
.
randint
(
0
,
2
**
32
-
1
)
random
.
seed
(
seed
)
numpy
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
test/ut/retiarii/debug_mnist_pytorch.py
View file @
6f3ed2bf
...
@@ -16,6 +16,7 @@ class _model(nn.Module):
...
@@ -16,6 +16,7 @@ class _model(nn.Module):
self
.
fc1
=
torch
.
nn
.
Linear
(
out_features
=
256
,
in_features
=
1024
)
self
.
fc1
=
torch
.
nn
.
Linear
(
out_features
=
256
,
in_features
=
1024
)
self
.
fc2
=
torch
.
nn
.
Linear
(
out_features
=
10
,
in_features
=
256
)
self
.
fc2
=
torch
.
nn
.
Linear
(
out_features
=
10
,
in_features
=
256
)
self
.
softmax
=
torch
.
nn
.
Softmax
()
self
.
softmax
=
torch
.
nn
.
Softmax
()
self
.
_mapping_
=
{
'stem'
:
None
,
'flatten'
:
None
,
'fc1'
:
None
,
'fc2'
:
None
,
'softmax'
:
None
}
def
forward
(
self
,
image
):
def
forward
(
self
,
image
):
stem
=
self
.
stem
(
image
)
stem
=
self
.
stem
(
image
)
...
@@ -34,6 +35,7 @@ class stem(nn.Module):
...
@@ -34,6 +35,7 @@ class stem(nn.Module):
self
.
pool1
=
torch
.
nn
.
MaxPool2d
(
kernel_size
=
2
)
self
.
pool1
=
torch
.
nn
.
MaxPool2d
(
kernel_size
=
2
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
=
64
,
in_channels
=
32
,
kernel_size
=
5
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
=
64
,
in_channels
=
32
,
kernel_size
=
5
)
self
.
pool2
=
torch
.
nn
.
MaxPool2d
(
kernel_size
=
2
)
self
.
pool2
=
torch
.
nn
.
MaxPool2d
(
kernel_size
=
2
)
self
.
_mapping_
=
{
'conv1'
:
None
,
'pool1'
:
None
,
'conv2'
:
None
,
'pool2'
:
None
}
def
forward
(
self
,
*
_inputs
):
def
forward
(
self
,
*
_inputs
):
conv1
=
self
.
conv1
(
_inputs
[
0
])
conv1
=
self
.
conv1
(
_inputs
[
0
])
...
...
test/ut/retiarii/test_cgo_engine.py
View file @
6f3ed2bf
...
@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
...
@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
from
pathlib
import
Path
from
pathlib
import
Path
import
nni
import
nni
import
nni.runtime.platform.test
try
:
try
:
from
nni.common.device
import
GPUDevice
from
nni.common.device
import
GPUDevice
...
...
test/ut/retiarii/test_convert.py
View file @
6f3ed2bf
...
@@ -14,6 +14,7 @@ import torchvision
...
@@ -14,6 +14,7 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
basic_unit
from
nni.retiarii
import
basic_unit
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
...
@@ -50,16 +51,6 @@ class Linear(nn.Module):
...
@@ -50,16 +51,6 @@ class Linear(nn.Module):
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
):
def
checkExportImport
(
self
,
model
,
input
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
@@ -68,9 +59,8 @@ class TestConvert(unittest.TestCase, ConvertMixin):
...
@@ -68,9 +59,8 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
with
original_state_dict_hooks
(
converted_model
):
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
dict
(
model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_basic.py
View file @
6f3ed2bf
...
@@ -12,20 +12,11 @@ from nni.retiarii import basic_unit
...
@@ -12,20 +12,11 @@ from nni.retiarii import basic_unit
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
# following pytorch v1.7.1
# following pytorch v1.7.1
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
class
TestConvert
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
@@ -35,9 +26,10 @@ class TestConvert(unittest.TestCase, ConvertMixin):
...
@@ -35,9 +26,10 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
converted_state_dict
)
converted_model
.
load_state_dict
(
model
.
state_dict
())
with
torch
.
no_grad
():
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_models.py
View file @
6f3ed2bf
...
@@ -9,23 +9,13 @@ import torch.nn.functional as F
...
@@ -9,23 +9,13 @@ import torch.nn.functional as F
import
torchvision
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestModels
(
unittest
.
TestCase
,
ConvertMixin
):
class
TestModels
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
@@ -35,9 +25,10 @@ class TestModels(unittest.TestCase, ConvertMixin):
...
@@ -35,9 +25,10 @@ class TestModels(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
converted_state_dict
)
converted_model
.
load_state_dict
(
model
.
state_dict
())
with
torch
.
no_grad
():
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_operators.py
View file @
6f3ed2bf
...
@@ -16,6 +16,7 @@ import torchvision
...
@@ -16,6 +16,7 @@ import torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
...
@@ -23,16 +24,6 @@ from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
...
@@ -23,16 +24,6 @@ from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class
TestOperators
(
unittest
.
TestCase
,
ConvertMixin
):
class
TestOperators
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
...
@@ -42,9 +33,10 @@ class TestOperators(unittest.TestCase, ConvertMixin):
...
@@ -42,9 +33,10 @@ class TestOperators(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
converted_state_dict
)
converted_model
.
load_state_dict
(
model
.
state_dict
())
with
torch
.
no_grad
():
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
...
test/ut/retiarii/test_convert_pytorch.py
View file @
6f3ed2bf
...
@@ -14,28 +14,17 @@ import torch.nn.functional as F
...
@@ -14,28 +14,17 @@ import torch.nn.functional as F
import
torchvision
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
serialize
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
original_state_dict_hooks
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
from
.convert_mixin
import
ConvertMixin
,
ConvertWithShapeMixin
class
TestPytorch
(
unittest
.
TestCase
,
ConvertMixin
):
class
TestPytorch
(
unittest
.
TestCase
,
ConvertMixin
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
run_test
(
self
,
model
,
input
,
check_value
=
True
,
strict_load
=
True
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
run_test
(
self
,
model
,
input
,
check_value
=
True
):
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_ir
=
self
.
_convert_model
(
model
,
input
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
from
.inject_nn
import
remove_inject_pytorch_nn
from
.inject_nn
import
remove_inject_pytorch_nn
remove_inject_pytorch_nn
()
remove_inject_pytorch_nn
()
...
@@ -43,9 +32,10 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
...
@@ -43,9 +32,10 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
exec_vars
=
{}
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
with
original_state_dict_hooks
(
converted_model
):
converted_model
.
load_state_dict
(
converted_state_dict
)
converted_model
.
load_state_dict
(
model
.
state_dict
(),
strict
=
strict_load
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
...
@@ -76,7 +66,8 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
...
@@ -76,7 +66,8 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
model
=
LargeModel
()
model
=
LargeModel
()
x
=
torch
.
tensor
([
2
],
dtype
=
torch
.
long
)
x
=
torch
.
tensor
([
2
],
dtype
=
torch
.
long
)
self
.
run_test
(
model
,
(
x
,
))
# emb and lin1 is actually not used so they won't appear in generated model
self
.
run_test
(
model
,
(
x
,
),
strict_load
=
False
)
@
unittest
.
skip
(
'skip for now, as it needs inject_nn'
)
@
unittest
.
skip
(
'skip for now, as it needs inject_nn'
)
def
test_mobilenet_v2_with_external_data
(
self
):
def
test_mobilenet_v2_with_external_data
(
self
):
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment