OpenDAS / nni · Commit a0fd0036 (unverified)

Authored Aug 01, 2022 by Yuge Zhang, committed via GitHub on Aug 01, 2022.

Merge pull request #5036 from microsoft/promote-retiarii-to-nas
[DO NOT SQUASH] Promote retiarii to NAS

Parents: d6dcb483, bc6d8796

Changes: 239. Showing 19 changed files with 41 additions and 1116 deletions (+41, -1116).
nni/retiarii/serializer.py                     +2   -179
nni/retiarii/strategy/base.py                  +2   -14
nni/retiarii/strategy/bruteforce.py            +2   -133
nni/retiarii/strategy/evolution.py             +2   -165
nni/retiarii/strategy/local_debug_strategy.py  +2   -42
nni/retiarii/strategy/oneshot.py               +2   -18
nni/retiarii/strategy/rl.py                    +2   -74
nni/retiarii/strategy/tpe_strategy.py          +2   -95
nni/retiarii/strategy/utils.py                 +2   -54
nni/retiarii/trial_entry.py                    +3   -23
nni/retiarii/utils.py                          +2   -305
pyrightconfig.json                             +2   -2
test/algo/nas/test_cgo_engine.py               +2   -0
test/algo/nas/test_space_hub.py                +2   -2
test/algo/nas/test_strategy.py                 +4   -1
test/ut/nas/debug_mnist_pytorch.py             +1   -1
test/ut/nas/test_engine.py                     +1   -0
test/ut/nas/test_mutator.py                    +4   -6
test/ut/nas/test_nn.py                         +2   -2
nni/retiarii/serializer.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import inspect
import os
import warnings
from typing import Any, TypeVar, Type

# pylint: disable=wildcard-import,unused-wildcard-import

from nni.common.serializer import is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from .utils import ModelNamespace

__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit',
           'model_wrapper', 'is_basic_unit', 'is_model_wrapped']

T = TypeVar('T')

def get_init_parameters_or_fail(obj: Any):
    if is_traceable(obj):
        return obj.trace_kwargs
    raise ValueError(f'Object {obj} needs to be serializable but `trace_kwargs` is not available. '
                     'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
                     'If it is a customized module, please to decorate it with @basic_unit. '
                     'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
                     'try to use @nni.trace.')

def serialize(cls, *args, **kwargs):
    """
    To create an serializable instance inline without decorator. For example,

    .. code-block:: python

        self.op = serialize(MyCustomOp, hidden_units=128)
    """
    warnings.warn('nni.retiarii.serialize is deprecated and will be removed in future release. ' +
                  'Try to use nni.trace, e.g., nni.trace(torch.optim.Adam)(learning_rate=1e-4) instead.',
                  category=DeprecationWarning)
    return trace(cls)(*args, **kwargs)

def serialize_cls(cls):
    """
    To create an serializable class.
    """
    warnings.warn('nni.retiarii.serialize is deprecated and will be removed in future release. ' +
                  'Try to use nni.trace instead.',
                  category=DeprecationWarning)
    return trace(cls)

def basic_unit(cls: T, basic_unit_tag: bool = True) -> T:
    """
    To wrap a module as a basic unit, is to make it a primitive and stop the engine from digging deeper into it.

    ``basic_unit_tag`` is true by default. If set to false, it will not be explicitly mark as a basic unit, and
    graph parser will continue to parse. Currently, this is to handle a special case in ``nn.Sequential``.

    Although ``basic_unit`` calls ``trace`` in its implementation, it is not for serialization. Rather, it is meant
    to capture the initialization arguments for mutation. Also, graph execution engine will stop digging into the inner
    modules when it reaches a module that is decorated with ``basic_unit``.

    .. code-block:: python

        @basic_unit
        class PrimitiveOp(nn.Module):
            ...
    """
    # Internal flag. See nni.trace
    nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
    if nni_trace_flag.lower() == 'disable':
        return cls

    if _check_wrapped(cls, 'basic_unit'):
        return cls

    import torch.nn as nn
    assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'  # type: ignore

    cls = trace(cls)
    cls._nni_basic_unit = basic_unit_tag  # type: ignore

    _torchscript_patch(cls)

    return cls

def model_wrapper(cls: T) -> T:
    """
    Wrap the base model (search space). For example,

    .. code-block:: python

        @model_wrapper
        class MyModel(nn.Module):
            ...

    The wrapper serves two purposes:

        1. Capture the init parameters of python class so that it can be re-instantiated in another process.
        2. Reset uid in namespace so that the auto label counting in each model stably starts from zero.

    Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed.
    But in future, we might enforce ``@model_wrapper`` to be required for base model.
    """
    # Internal flag. See nni.trace
    nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
    if nni_trace_flag.lower() == 'disable':
        return cls

    if _check_wrapped(cls, 'model_wrapper'):
        return cls

    import torch.nn as nn
    assert issubclass(cls, nn.Module)  # type: ignore

    # subclass can still use trace info
    wrapper = trace(cls, inheritable=True)

    class reset_wrapper(wrapper):
        def __init__(self, *args, **kwargs):
            self._model_namespace = ModelNamespace()
            with self._model_namespace:
                super().__init__(*args, **kwargs)

    _copy_class_wrapper_attributes(wrapper, reset_wrapper)
    reset_wrapper.__wrapped__ = getattr(wrapper, '__wrapped__', wrapper)
    reset_wrapper._nni_model_wrapper = True
    reset_wrapper._traced = True

    _torchscript_patch(cls)

    return reset_wrapper

def is_basic_unit(cls_or_instance) -> bool:
    if not inspect.isclass(cls_or_instance):
        cls_or_instance = cls_or_instance.__class__
    return getattr(cls_or_instance, '_nni_basic_unit', False)

def is_model_wrapped(cls_or_instance) -> bool:
    if not inspect.isclass(cls_or_instance):
        cls_or_instance = cls_or_instance.__class__
    return getattr(cls_or_instance, '_nni_model_wrapper', False)

def _check_wrapped(cls: Type, rewrap: str) -> bool:
    wrapped = None
    if is_model_wrapped(cls):
        wrapped = 'model_wrapper'
    elif is_basic_unit(cls):
        wrapped = 'basic_unit'
    elif is_wrapped_with_trace(cls):
        wrapped = 'nni.trace'
    if wrapped:
        if wrapped != rewrap:
            raise TypeError(f'{cls} is already wrapped with {wrapped}. Cannot rewrap with {rewrap}.')
        return True
    return False

def _torchscript_patch(cls) -> None:
    # HACK: for torch script
    # https://github.com/pytorch/pytorch/pull/45261
    # https://github.com/pytorch/pytorch/issues/54688
    # I'm not sure whether there will be potential issues
    import torch

    if hasattr(cls, '_get_nni_attr'):  # could not exist on non-linux
        cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)

    if hasattr(cls, 'trace_symbol'):
        # these must all exist or all non-exist
        try:
            cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
            cls.trace_args = torch.jit.unused(cls.trace_args)
            cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
            cls.trace_copy = torch.jit.ignore(cls.trace_copy)
        except AttributeError as e:
            if 'property' in str(e):
                raise RuntimeError('Trace on PyTorch module failed. Your PyTorch version might be outdated. '
                                   'Please try to upgrade PyTorch.')
            raise

from nni.nas.utils.serializer import *
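For context, a minimal sketch (not part of the diff) of how the two decorators above are meant to be used, following their docstrings and assuming PyTorch and this revision of nni are installed:

# basic_unit marks a module as a primitive; model_wrapper marks the search-space root.
import torch.nn as nn
from nni.retiarii import basic_unit, model_wrapper

@basic_unit
class PrimitiveOp(nn.Module):
    # the graph parser treats this module as a primitive and does not dig into it
    def __init__(self, hidden_units=128):
        super().__init__()
        self.fc = nn.Linear(hidden_units, hidden_units)

    def forward(self, x):
        return self.fc(x)

@model_wrapper
class MyModel(nn.Module):
    # init arguments are captured so the model can be re-instantiated in another process,
    # and auto label numbering restarts from zero for every instance
    def __init__(self):
        super().__init__()
        self.op = PrimitiveOp(hidden_units=64)

    def forward(self, x):
        return self.op(x)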
nni/retiarii/strategy/base.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
from typing import List, Any

# pylint: disable=wildcard-import,unused-wildcard-import

from ..graph import Model
from ..mutator import Mutator

class BaseStrategy(abc.ABC):

    @abc.abstractmethod
    def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
        pass

    def export_top_models(self, top_k: int) -> List[Any]:
        raise NotImplementedError('"export_top_models" is not implemented.')

from nni.nas.strategy.base import *
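A sketch (not part of the diff) of what a custom strategy built on this interface could look like; submit_models, query_available_resources and Sampler are the execution/mutation APIs used by the bundled strategies later in this commit, assumed importable from nni.retiarii in this revision:

import time
from nni.retiarii import Sampler, submit_models
from nni.retiarii.execution import query_available_resources
from nni.retiarii.strategy.base import BaseStrategy

class _FirstChoiceSampler(Sampler):
    # always pick the first candidate offered by each mutator
    def choice(self, candidates, mutator, model, index):
        return candidates[0]

class SubmitFirstCandidate(BaseStrategy):
    """Toy strategy: materialize one model and submit it once a trial slot is free."""

    def run(self, base_model, applied_mutators):
        model = base_model
        for mutator in applied_mutators:
            model = mutator.bind_sampler(_FirstChoiceSampler()).apply(model)
        while query_available_resources() <= 0:
            time.sleep(1)          # wait for a free trial slot
        submit_models(model)       # hand the mutated model to the execution engine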
nni/retiarii/strategy/bruteforce.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import itertools
import logging
import random
import time
from typing import Any, Dict, List, Sequence, Optional

# pylint: disable=wildcard-import,unused-wildcard-import

from .. import InvalidMutation, Sampler, submit_models, query_available_resources, budget_exhausted
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model

_logger = logging.getLogger(__name__)

def grid_generator(search_space: Dict[Any, List[Any]], shuffle=True):
    keys = list(search_space.keys())
    search_space_values = copy.deepcopy(list(search_space.values()))
    if shuffle:
        for values in search_space_values:
            random.shuffle(values)
    for values in itertools.product(*search_space_values):
        yield {key: value for key, value in zip(keys, values)}

def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500):
    keys = list(search_space.keys())
    history = set()
    search_space_values = copy.deepcopy(list(search_space.values()))
    while True:
        selected: Optional[Sequence[int]] = None
        for retry_count in range(retries):
            selected = [random.choice(v) for v in search_space_values]
            if not dedup:
                break
            selected = tuple(selected)
            if selected not in history:
                history.add(selected)
                break
            if retry_count + 1 == retries:
                _logger.debug('Random generation has run out of patience. There is nothing to search. Exiting.')
                return
        assert selected is not None, 'Retry attempts exhausted.'
        yield {key: value for key, value in zip(keys, selected)}

class GridSearch(BaseStrategy):
    """
    Traverse the search space and try all the possible combinations one by one.

    Parameters
    ----------
    shuffle : bool
        Shuffle the order in a candidate list, so that they are tried in a random order. Default: true.
    """

    def __init__(self, shuffle=True):
        self._polling_interval = 2.
        self.shuffle = shuffle

    def run(self, base_model, applied_mutators):
        search_space = dry_run_for_search_space(base_model, applied_mutators)
        for sample in grid_generator(search_space, shuffle=self.shuffle):
            _logger.debug('New model created. Waiting for resource. %s', str(sample))
            while query_available_resources() <= 0:
                if budget_exhausted():
                    return
                time.sleep(self._polling_interval)
            submit_models(get_targeted_model(base_model, applied_mutators, sample))

class _RandomSampler(Sampler):
    def choice(self, candidates, mutator, model, index):
        return random.choice(candidates)

class Random(BaseStrategy):
    """
    Random search on the search space.

    Parameters
    ----------
    variational : bool
        Do not dry run to get the full search space. Used when the search space has variational size or candidates. Default: false.
    dedup : bool
        Do not try the same configuration twice. When variational is true, deduplication is not supported. Default: true.
    model_filter: Callable[[Model], bool]
        Feed the model and return a bool. This will filter the models in search space and select which to submit.
    """

    def __init__(self, variational=False, dedup=True, model_filter=None):
        self.variational = variational
        self.dedup = dedup
        if variational and dedup:
            raise ValueError('Dedup is not supported in variational mode.')
        self.random_sampler = _RandomSampler()
        self._polling_interval = 2.
        self.filter = model_filter

    def run(self, base_model, applied_mutators):
        if self.variational:
            _logger.info('Random search running in variational mode.')
            sampler = _RandomSampler()
            for mutator in applied_mutators:
                mutator.bind_sampler(sampler)
            while True:
                avail_resource = query_available_resources()
                if avail_resource > 0:
                    model = base_model
                    for mutator in applied_mutators:
                        model = mutator.apply(model)
                    _logger.debug('New model created. Applied mutators are: %s', str(applied_mutators))
                    if filter_model(self.filter, model):
                        submit_models(model)
                elif budget_exhausted():
                    break
                else:
                    time.sleep(self._polling_interval)
        else:
            _logger.info('Random search running in fixed size mode. Dedup: %s.', 'on' if self.dedup else 'off')
            search_space = dry_run_for_search_space(base_model, applied_mutators)
            for sample in random_generator(search_space, dedup=self.dedup):
                _logger.debug('New model created. Waiting for resource. %s', str(sample))
                while query_available_resources() <= 0:
                    if budget_exhausted():
                        return
                    time.sleep(self._polling_interval)
                    _logger.debug('Still waiting for resource.')
                try:
                    model = get_targeted_model(base_model, applied_mutators, sample)
                    if filter_model(self.filter, model):
                        _logger.debug('Submitting model: %s', model)
                        submit_models(model)
                except InvalidMutation as e:
                    _logger.warning(f'Invalid mutation: {e}. Skip.')

from nni.nas.strategy.bruteforce import *
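A quick illustration (not part of the diff) of what the two generators above yield, assuming the module still re-exports them in this revision. In the strategies the keys are (mutator, index) pairs produced by dry_run_for_search_space, but any dict of candidate lists works:

from nni.retiarii.strategy.bruteforce import grid_generator, random_generator

space = {'kernel': [3, 5, 7], 'width': [16, 32]}

# grid_generator enumerates the full cartesian product (6 samples here);
# shuffle=True would first shuffle the order inside each candidate list
for sample in grid_generator(space, shuffle=False):
    print(sample)        # {'kernel': 3, 'width': 16}, {'kernel': 3, 'width': 32}, ...

# random_generator draws samples forever; with dedup=True it gives up (returns)
# only after `retries` consecutive draws that have all been seen before
gen = random_generator(space, dedup=True)
print(next(gen))         # one random, not-yet-seen combination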
nni/retiarii/strategy/evolution.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import collections
import dataclasses
import logging
import random
import time

# pylint: disable=wildcard-import,unused-wildcard-import

from ..execution import query_available_resources, submit_models
from ..graph import ModelStatus
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model

_logger = logging.getLogger(__name__)

@dataclasses.dataclass
class Individual:
    """
    A class that represents an individual.
    Holds two attributes, where ``x`` is the model and ``y`` is the metric (e.g., accuracy).
    """
    x: dict
    y: float

class RegularizedEvolution(BaseStrategy):
    """
    Algorithm for regularized evolution (i.e. aging evolution).
    Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image Classifier Architecture Search".

    Parameters
    ----------
    optimize_mode : str
        Can be one of "maximize" and "minimize". Default: maximize.
    population_size : int
        The number of individuals to keep in the population. Default: 100.
    cycles : int
        The number of cycles (trials) the algorithm should run for. Default: 20000.
    sample_size : int
        The number of individuals that should participate in each tournament. Default: 25.
    mutation_prob : float
        Probability that mutation happens in each dim. Default: 0.05
    on_failure : str
        Can be one of "ignore" and "worst". If "ignore", simply give up the model and find a new one.
        If "worst", mark the model as -inf (if maximize, inf if minimize), so that the algorithm "learns" to avoid such model.
        Default: ignore.
    model_filter: Callable[[Model], bool]
        Feed the model and return a bool. This will filter the models in search space and select which to submit.
    """

    def __init__(self, optimize_mode='maximize', population_size=100, sample_size=25, cycles=20000,
                 mutation_prob=0.05, on_failure='ignore', model_filter=None):
        assert optimize_mode in ['maximize', 'minimize']
        assert on_failure in ['ignore', 'worst']
        assert sample_size < population_size
        self.optimize_mode = optimize_mode
        self.population_size = population_size
        self.sample_size = sample_size
        self.cycles = cycles
        self.mutation_prob = mutation_prob
        self.on_failure = on_failure

        self._worst = float('-inf') if self.optimize_mode == 'maximize' else float('inf')

        self._success_count = 0
        self._population = collections.deque()
        self._running_models = []
        self._polling_interval = 2.
        self.filter = model_filter

    def random(self, search_space):
        return {k: random.choice(v) for k, v in search_space.items()}

    def mutate(self, parent, search_space):
        child = {}
        for k, v in parent.items():
            if random.uniform(0, 1) < self.mutation_prob:
                # NOTE: we do not exclude the original choice here for simplicity,
                # which is slightly different from the original paper.
                child[k] = random.choice(search_space[k])
            else:
                child[k] = v
        return child

    def best_parent(self):
        samples = [p for p in self._population]  # copy population
        random.shuffle(samples)
        samples = list(samples)[:self.sample_size]
        if self.optimize_mode == 'maximize':
            parent = max(samples, key=lambda sample: sample.y)
        else:
            parent = min(samples, key=lambda sample: sample.y)
        return parent.x

    def run(self, base_model, applied_mutators):
        search_space = dry_run_for_search_space(base_model, applied_mutators)
        # Run the first population regardless concurrency
        _logger.info('Initializing the first population.')
        while len(self._population) + len(self._running_models) <= self.population_size:
            # try to submit new models
            while len(self._population) + len(self._running_models) < self.population_size:
                config = self.random(search_space)
                self._submit_config(config, base_model, applied_mutators)
            # collect results
            self._move_succeeded_models_to_population()
            self._remove_failed_models_from_running_list()
            time.sleep(self._polling_interval)

            if len(self._population) >= self.population_size:
                break

        # Resource-aware mutation of models
        _logger.info('Running mutations.')
        while self._success_count + len(self._running_models) <= self.cycles:
            # try to submit new models
            while query_available_resources() > 0 and self._success_count + len(self._running_models) < self.cycles:
                config = self.mutate(self.best_parent(), search_space)
                self._submit_config(config, base_model, applied_mutators)
            # collect results
            self._move_succeeded_models_to_population()
            self._remove_failed_models_from_running_list()
            time.sleep(self._polling_interval)

            if self._success_count >= self.cycles:
                break

    def _submit_config(self, config, base_model, mutators):
        _logger.debug('Model submitted to running queue: %s', config)
        model = get_targeted_model(base_model, mutators, config)
        if not filter_model(self.filter, model):
            if self.on_failure == "worst":
                model.status = ModelStatus.Failed
                self._running_models.append((config, model))
        else:
            submit_models(model)
            self._running_models.append((config, model))
        return model

    def _move_succeeded_models_to_population(self):
        completed_indices = []
        for i, (config, model) in enumerate(self._running_models):
            metric = None
            if self.on_failure == 'worst' and model.status == ModelStatus.Failed:
                metric = self._worst
            elif model.status == ModelStatus.Trained:
                metric = model.metric
            if metric is not None:
                individual = Individual(config, metric)
                _logger.debug('Individual created: %s', str(individual))
                self._population.append(individual)
                if len(self._population) > self.population_size:
                    self._population.popleft()
                completed_indices.append(i)
        for i in completed_indices[::-1]:
            # delete from end to start so that the index number will not be affected.
            self._success_count += 1
            self._running_models.pop(i)

    def _remove_failed_models_from_running_list(self):
        # This is only done when on_failure policy is set to "ignore".
        # Otherwise, failed models will be treated as inf when processed.
        if self.on_failure == 'ignore':
            number_of_failed_models = len([g for g in self._running_models if g[1].status == ModelStatus.Failed])
            self._running_models = [g for g in self._running_models if g[1].status != ModelStatus.Failed]
            if number_of_failed_models > 0:
                _logger.info('%d failed models are ignored. Will retry.', number_of_failed_models)

from nni.nas.strategy.evolution import *
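A small sketch (not part of the diff) that exercises the aging-evolution step implemented above on a toy search space, without launching any trials. It pokes at the private _population deque purely for illustration and assumes the classes remain importable from nni.retiarii.strategy in this revision:

import collections
from nni.retiarii.strategy import RegularizedEvolution
from nni.retiarii.strategy.evolution import Individual

space = {'kernel': [3, 5, 7], 'width': [16, 32, 64]}
strategy = RegularizedEvolution(population_size=8, sample_size=3, mutation_prob=0.5)

# pretend 8 individuals have already been evaluated (metric = i, maximize mode)
strategy._population = collections.deque(
    Individual(strategy.random(space), float(i)) for i in range(8)
)

parent = strategy.best_parent()          # tournament: sample 3 individuals, keep the best metric
child = strategy.mutate(parent, space)   # each key is re-sampled with probability 0.5
print(parent, child)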
nni/retiarii/strategy/local_debug_strategy.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import random
import string

# pylint: disable=wildcard-import,unused-wildcard-import

from .. import Sampler, codegen, utils
from ..execution.base import BaseGraphData
from ..execution.utils import get_mutation_summary
from .base import BaseStrategy

_logger = logging.getLogger(__name__)

class ChooseFirstSampler(Sampler):
    def choice(self, candidates, mutator, model, index):
        return candidates[0]

class _LocalDebugStrategy(BaseStrategy):
    """
    This class is supposed to be used internally, for debugging trial mutation
    """

    def run_one_model(self, model):
        mutation_summary = get_mutation_summary(model)
        graph_data = BaseGraphData(codegen.pytorch.model_to_pytorch_script(model), model.evaluator, mutation_summary)  # type: ignore
        random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
        file_name = f'_generated_model/{random_str}.py'
        os.makedirs(os.path.dirname(file_name), exist_ok=True)
        with open(file_name, 'w') as f:
            f.write(graph_data.model_script)
        model_cls = utils.import_(f'_generated_model.{random_str}._model')
        graph_data.evaluator._execute(model_cls)
        os.remove(file_name)

    def run(self, base_model, applied_mutators):
        _logger.info('local debug strategy has been started.')
        model = base_model
        _logger.debug('New model created. Applied mutators: %s', str(applied_mutators))
        choose_first_sampler = ChooseFirstSampler()
        for mutator in applied_mutators:
            mutator.bind_sampler(choose_first_sampler)
            model = mutator.apply(model)
        # directly run models
        self.run_one_model(model)

from nni.nas.strategy.debug import *
nni/retiarii/strategy/oneshot.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .base import BaseStrategy

# pylint: disable=wildcard-import,unused-wildcard-import

try:
    from nni.retiarii.oneshot.pytorch.strategy import (  # pylint: disable=unused-import
        DARTS, GumbelDARTS, Proxyless, ENAS, RandomOneShot
    )
except ImportError as import_err:
    _import_err = import_err

    class ImportFailedStrategy(BaseStrategy):
        def run(self, base_model, applied_mutators):
            raise _import_err

    # otherwise typing check will pointing to the wrong location
    globals()['DARTS'] = ImportFailedStrategy
    globals()['GumbelDARTS'] = ImportFailedStrategy
    globals()['Proxyless'] = ImportFailedStrategy
    globals()['ENAS'] = ImportFailedStrategy
    globals()['RandomOneShot'] = ImportFailedStrategy

from nni.nas.strategy.oneshot import *
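The block above defers the ImportError: the strategy names stay importable, and the original error surfaces only when a strategy is actually run. A condensed sketch of the same pattern with a hypothetical optional dependency (not part of the diff):

try:
    from some_optional_backend import FancyStrategy  # hypothetical package, may be absent
except ImportError as import_err:
    _import_err = import_err

    class _ImportFailed:
        def run(self, *args, **kwargs):
            raise _import_err  # re-raise the original error only on use

    # assigning through globals() keeps static type checkers pointed at the real class
    globals()['FancyStrategy'] = _ImportFailed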
nni/retiarii/strategy/rl.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Optional, Callable

# pylint: disable=wildcard-import,unused-wildcard-import

from .base import BaseStrategy
from .utils import dry_run_for_search_space
from ..execution import query_available_resources

try:
    has_tianshou = True
    import torch
    from tianshou.data import Collector, VectorReplayBuffer
    from tianshou.env import BaseVectorEnv
    from tianshou.policy import BasePolicy, PPOPolicy  # pylint: disable=unused-import
    from ._rl_impl import ModelEvaluationEnv, MultiThreadEnvWorker, Preprocessor, Actor, Critic
except ImportError:
    has_tianshou = False

_logger = logging.getLogger(__name__)

class PolicyBasedRL(BaseStrategy):
    """
    Algorithm for policy-based reinforcement learning.
    This is a wrapper of algorithms provided in tianshou (PPO by default),
    and can be easily customized with other algorithms that inherit ``BasePolicy``
    (e.g., `REINFORCE <https://link.springer.com/content/pdf/10.1007/BF00992696.pdf>`__
    as in `this paper <https://arxiv.org/abs/1611.01578>`__).

    Parameters
    ----------
    max_collect : int
        How many times collector runs to collect trials for RL. Default 100.
    trial_per_collect : int
        How many trials (trajectories) each time collector collects.
        After each collect, trainer will sample batch from replay buffer and do the update. Default: 20.
    policy_fn : function
        Takes :class:`ModelEvaluationEnv` as input and return a policy.
        See :meth:`PolicyBasedRL._default_policy_fn` for an example.
    """

    def __init__(self, max_collect: int = 100, trial_per_collect=20,
                 policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None):
        if not has_tianshou:
            raise ImportError('`tianshou` is required to run RL-based strategy. '
                              'Please use "pip install tianshou" to install it beforehand.')

        self.policy_fn = policy_fn or self._default_policy_fn
        self.max_collect = max_collect
        self.trial_per_collect = trial_per_collect

    @staticmethod
    def _default_policy_fn(env):
        net = Preprocessor(env.observation_space)
        actor = Actor(env.action_space, net)
        critic = Critic(net)
        optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=1e-4)
        return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
                         discount_factor=1., action_space=env.action_space)

    def run(self, base_model, applied_mutators):
        search_space = dry_run_for_search_space(base_model, applied_mutators)
        concurrency = query_available_resources()

        env_fn = lambda: ModelEvaluationEnv(base_model, applied_mutators, search_space)
        policy = self.policy_fn(env_fn())

        env = BaseVectorEnv([env_fn for _ in range(concurrency)], MultiThreadEnvWorker)
        collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))

        for cur_collect in range(1, self.max_collect + 1):
            _logger.info('Collect [%d] Running...', cur_collect)
            result = collector.collect(n_episode=self.trial_per_collect)
            _logger.info('Collect [%d] Result: %s', cur_collect, str(result))
            policy.update(0, collector.buffer, batch_size=64, repeat=5)

from nni.nas.strategy.rl import *
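A sketch (not part of the diff) of plugging a customized policy into PolicyBasedRL. It mirrors _default_policy_fn above with different optimizer settings, and assumes tianshou is installed so that the helper classes conditionally imported at the top of this module (Preprocessor, Actor, Critic, PPOPolicy) are available:

import torch
from nni.retiarii.strategy import PolicyBasedRL
from nni.retiarii.strategy.rl import Preprocessor, Actor, Critic, PPOPolicy

def my_policy_fn(env):
    # same building blocks as the default, with a smaller learning rate and mild discount
    net = Preprocessor(env.observation_space)
    actor = Actor(env.action_space, net)
    critic = Critic(net)
    optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=3e-4)
    return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
                     discount_factor=0.99, action_space=env.action_space)

strategy = PolicyBasedRL(max_collect=50, trial_per_collect=10, policy_fn=my_policy_fn)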
nni/retiarii/strategy/tpe_strategy.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import time
from typing import Optional

# pylint: disable=wildcard-import,unused-wildcard-import

from .. import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
from .base import BaseStrategy

_logger = logging.getLogger(__name__)

class TPESampler(Sampler):
    def __init__(self, optimize_mode='minimize'):
        # Move import here to eliminate some warning messages about dill.
        from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner

        self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
        self.cur_sample: Optional[dict] = None
        self.index: Optional[int] = None
        self.total_parameters = {}

    def update_sample_space(self, sample_space):
        search_space = {}
        for i, each in enumerate(sample_space):
            search_space[str(i)] = {'_type': 'choice', '_value': each}
        self.tpe_tuner.update_search_space(search_space)

    def generate_samples(self, model_id):
        self.cur_sample = self.tpe_tuner.generate_parameters(model_id)
        self.total_parameters[model_id] = self.cur_sample
        self.index = 0

    def receive_result(self, model_id, result):
        self.tpe_tuner.receive_trial_result(model_id, self.total_parameters[model_id], result)

    def choice(self, candidates, mutator, model, index):
        assert isinstance(self.index, int) and isinstance(self.cur_sample, dict)
        chosen = self.cur_sample[str(self.index)]
        self.index += 1
        return chosen

class TPE(BaseStrategy):
    """
    The Tree-structured Parzen Estimator (TPE) is a sequential model-based optimization (SMBO) approach.
    Find the details in
    `Algorithms for Hyper-Parameter Optimization <https://papers.nips.cc/paper/2011/file/86e8f7ab32cfd12577bc2619bc635690-Paper.pdf>`__.

    SMBO methods sequentially construct models to approximate the performance of hyperparameters based on historical measurements,
    and then subsequently choose new hyperparameters to test based on this model.
    """

    def __init__(self):
        self.tpe_sampler = TPESampler()
        self.model_id = 0
        self.running_models = {}

    def run(self, base_model, applied_mutators):
        sample_space = []
        new_model = base_model
        for mutator in applied_mutators:
            recorded_candidates, new_model = mutator.dry_run(new_model)
            sample_space.extend(recorded_candidates)
        self.tpe_sampler.update_sample_space(sample_space)

        _logger.info('TPE strategy has been started.')
        while not budget_exhausted():
            avail_resource = query_available_resources()
            if avail_resource > 0:
                model = base_model
                _logger.debug('New model created. Applied mutators: %s', str(applied_mutators))
                self.tpe_sampler.generate_samples(self.model_id)
                for mutator in applied_mutators:
                    mutator.bind_sampler(self.tpe_sampler)
                    model = mutator.apply(model)
                # run models
                submit_models(model)
                self.running_models[self.model_id] = model
                self.model_id += 1
            else:
                time.sleep(2)

            _logger.debug('num of running models: %d', len(self.running_models))
            to_be_deleted = []
            for _id, _model in self.running_models.items():
                if is_stopped_exec(_model):
                    if _model.metric is not None:
                        self.tpe_sampler.receive_result(_id, _model.metric)
                        _logger.debug('tpe receive results: %d, %s', _id, _model.metric)
                    to_be_deleted.append(_id)
            for _id in to_be_deleted:
                del self.running_models[_id]

# alias for backward compatibility
TPEStrategy = TPE

from nni.nas.strategy.hpo import *
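A small sketch (not part of the diff) of how TPESampler above flattens recorded mutator choices into an HPO-style search space, with every decision point becoming a "choice" parameter keyed by its index. It assumes the class is still importable from this module in the given revision:

from nni.retiarii.strategy.tpe_strategy import TPESampler

sampler = TPESampler(optimize_mode='maximize')
# two decision points recorded by a dry run: 3 kernel sizes and 2 widths
sampler.update_sample_space([[3, 5, 7], [16, 32]])
# internally this becomes {'0': {'_type': 'choice', '_value': [3, 5, 7]},
#                          '1': {'_type': 'choice', '_value': [16, 32]}}

sampler.generate_samples(model_id=0)   # ask the TPE tuner for one flattened sample
print(sampler.cur_sample)              # e.g. {'0': 5, '1': 32}
sampler.receive_result(0, 0.93)        # report the trial metric back to the tuner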
nni/retiarii/strategy/utils.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import collections
import logging
from typing import Dict, Any, List

from ..graph import Model
from ..mutator import Mutator, Sampler

# pylint: disable=wildcard-import,unused-wildcard-import

_logger = logging.getLogger(__name__)

class _FixedSampler(Sampler):
    def __init__(self, sample):
        self.sample = sample

    def choice(self, candidates, mutator, model, index):
        return self.sample[(mutator, index)]

def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, List[Any]]:
    search_space = collections.OrderedDict()
    for mutator in mutators:
        recorded_candidates, model = mutator.dry_run(model)
        for i, candidates in enumerate(recorded_candidates):
            search_space[(mutator, i)] = candidates
    return search_space

def dry_run_for_formatted_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, Dict[Any, Any]]:
    search_space = collections.OrderedDict()
    for mutator in mutators:
        recorded_candidates, model = mutator.dry_run(model)
        if len(recorded_candidates) == 1:
            search_space[mutator.label] = {'_type': 'choice', '_value': recorded_candidates[0]}
        else:
            for i, candidate in enumerate(recorded_candidates):
                search_space[f'{mutator.label}_{i}'] = {'_type': 'choice', '_value': candidate}
    return search_space

def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model:
    sampler = _FixedSampler(sample)
    model = base_model
    for mutator in mutators:
        model = mutator.bind_sampler(sampler).apply(model)
    return model

def filter_model(model_filter, ir_model):
    if model_filter is not None:
        _logger.debug(f'Check if model satisfies constraints.')
        if model_filter(ir_model):
            _logger.debug(f'Model satisfied. Submit the model.')
            return True
        else:
            _logger.debug(f'Model unsatisfied. Discard the model.')
            return False
    else:
        return True

from nni.nas.strategy.utils import *
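A sketch (not part of the diff) of how the helpers above are combined by the strategies: dry-run the mutators to enumerate choices, pick one value per (mutator, index) key, then materialize a concrete model with a fixed sampler. The function names are assumed to remain re-exported from this module:

from nni.retiarii.strategy.utils import dry_run_for_search_space, get_targeted_model, filter_model

def first_choice_of_everything(base_model, applied_mutators):
    search_space = dry_run_for_search_space(base_model, applied_mutators)
    # keys are (mutator, index) pairs, values are candidate lists
    sample = {key: candidates[0] for key, candidates in search_space.items()}
    model = get_targeted_model(base_model, applied_mutators, sample)
    assert filter_model(None, model)   # a None filter always accepts
    return model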
nni/retiarii/trial_entry.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Entrypoint for trials.

Assuming execution engine is BaseExecutionEngine.
"""
# pylint: disable=wildcard-import,unused-wildcard-import

import argparse

from nni.nas.execution.trial_entry import main

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('exec', choices=['base', 'py', 'cgo', 'benchmark'])
    args = parser.parse_args()
    if args.exec == 'base':
        from .execution.base import BaseExecutionEngine
        engine = BaseExecutionEngine
    elif args.exec == 'cgo':
        from .execution.cgo_engine import CGOExecutionEngine
        engine = CGOExecutionEngine
    elif args.exec == 'py':
        from .execution.python import PurePythonExecutionEngine
        engine = PurePythonExecutionEngine
    elif args.exec == 'benchmark':
        from .execution.benchmark import BenchmarkExecutionEngine
        engine = BenchmarkExecutionEngine
    else:
        raise ValueError(f'Unrecognized benchmark name: {args.exec}')
    engine.trial_execute_graph()

    main()
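A sketch (not part of the diff) of what the dispatch above does: the single positional argument names the execution engine that replays the generated graph inside a trial. The module is normally launched by NNI itself as the trial command; invoking it manually would look roughly like this:

import subprocess
import sys

# equivalent to the trial command `python -m nni.retiarii.trial_entry py`,
# where 'py' selects PurePythonExecutionEngine ('base', 'cgo', 'benchmark' select the others)
subprocess.run([sys.executable, '-m', 'nni.retiarii.trial_entry', 'py'], check=True)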
nni/retiarii/utils.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import inspect
import itertools
import warnings
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, List, Dict, cast
from pathlib import Path

# pylint: disable=wildcard-import,unused-wildcard-import

from nni.common.hpo_utils import ParameterSpec

__all__ = ['NoContextError', 'ContextStack', 'ModelNamespace', 'original_state_dict_hooks']

def import_(target: str, allow_none: bool = False) -> Any:
    if target is None:
        return None
    path, identifier = target.rsplit('.', 1)
    module = __import__(path, globals(), locals(), [identifier])
    return getattr(module, identifier)

_last_uid = defaultdict(int)

_DEFAULT_MODEL_NAMESPACE = 'model'

def uid(namespace: str = 'default') -> int:
    _last_uid[namespace] += 1
    return _last_uid[namespace]

def reset_uid(namespace: str = 'default') -> None:
    _last_uid[namespace] = 0

def get_module_name(cls_or_func):
    module_name = cls_or_func.__module__
    if module_name == '__main__':
        # infer the module name with inspect
        for frm in inspect.stack():
            module = inspect.getmodule(frm[0])
            if module is not None and module.__name__ == '__main__':
                # main module found
                main_file_path = Path(cast(str, inspect.getsourcefile(frm[0])))
                if not Path().samefile(main_file_path.parent):
                    raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
                                       f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
                module_name = main_file_path.stem
                break
    if module_name == '__main__':
        warnings.warn('Callstack exhausted but main module still not found. This will probably cause issues that the '
                      'function/class cannot be imported.')

    # NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
    # to make LSTM's source code can be found, we should assign original LSTM's __module__ to
    # the wrapped LSTM's __module__
    # TODO: find out all the modules that have the same requirement as LSTM
    if f'{cls_or_func.__module__}.{cls_or_func.__name__}' == 'torch.nn.modules.rnn.LSTM':
        module_name = cls_or_func.__module__

    return module_name

def get_importable_name(cls, relocate_module=False):
    module_name = get_module_name(cls) if relocate_module else cls.__module__
    return module_name + '.' + cls.__name__

class NoContextError(Exception):
    """Exception raised when context is missing."""
    pass

class ContextStack:
    """
    This is to maintain a globally-accessible context environment that is visible to everywhere.

    Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
    get the corresponding value in the namespace.

    Note that this is not multi-processing safe. Also, the values will get cleared for a new process.
    """

    _stack: Dict[str, List[Any]] = defaultdict(list)

    def __init__(self, key: str, value: Any):
        self.key = key
        self.value = value

    def __enter__(self):
        self.push(self.key, self.value)
        return self

    def __exit__(self, *args, **kwargs):
        self.pop(self.key)

    @classmethod
    def push(cls, key: str, value: Any):
        cls._stack[key].append(value)

    @classmethod
    def pop(cls, key: str) -> None:
        cls._stack[key].pop()

    @classmethod
    def top(cls, key: str) -> Any:
        if not cls._stack[key]:
            raise NoContextError('Context is empty.')
        return cls._stack[key][-1]

class ModelNamespace:
    """
    To create an individual namespace for models:

    1. to enable automatic numbering;
    2. to trace general information (like creation of hyper-parameters) of model.

    A namespace is bounded to a key. Namespace bounded to different keys are completed isolated.

    Namespace can have sub-namespaces (with the same key). The numbering will be chained (e.g., ``model_1_4_2``).
    """

    def __init__(self, key: str = _DEFAULT_MODEL_NAMESPACE):
        # for example, key: "model_wrapper"
        self.key = key

        # the "path" of current name
        # By default, it's ``[]``
        # If a ``@model_wrapper`` is nested inside a model_wrapper, it will become something like ``[1, 3, 2]``.
        # See ``__enter__``.
        self.name_path: List[int] = []

        # parameter specs.
        # Currently only used trace calls of ModelParameterChoice.
        self.parameter_specs: List[ParameterSpec] = []

    def __enter__(self):
        # For example, currently the top of stack is [1, 2, 2], and [1, 2, 2, 3] is used,
        # the next thing up is [1, 2, 2, 4].
        # `reset_uid` to count from zero for "model_wrapper_1_2_2_4"
        try:
            parent_context: 'ModelNamespace' = ModelNamespace.current_context(self.key)
            next_uid = uid(parent_context._simple_name())
            self.name_path = parent_context.name_path + [next_uid]
            ContextStack.push(self.key, self)
            reset_uid(self._simple_name())
        except NoContextError:
            # not found, no existing namespace
            self.name_path = []
            ContextStack.push(self.key, self)
            reset_uid(self._simple_name())

    def __exit__(self, *args, **kwargs):
        ContextStack.pop(self.key)

    def _simple_name(self) -> str:
        return self.key + ''.join(['_' + str(k) for k in self.name_path])

    def __repr__(self):
        return f'ModelNamespace(name={self._simple_name()}, num_specs={len(self.parameter_specs)})'

    # Access the current context in the model #

    @staticmethod
    def current_context(key: str = _DEFAULT_MODEL_NAMESPACE) -> 'ModelNamespace':
        """Get the current context in key."""
        try:
            return ContextStack.top(key)
        except NoContextError:
            raise NoContextError('ModelNamespace context is missing. You might have forgotten to use `@model_wrapper`.')

    @staticmethod
    def next_label(key: str = _DEFAULT_MODEL_NAMESPACE) -> str:
        """Get the next label for API calls, with automatic numbering."""
        try:
            current_context = ContextStack.top(key)
        except NoContextError:
            # fallback to use "default" namespace
            # it won't be registered
            warnings.warn('ModelNamespace is missing. You might have forgotten to use `@model_wrapper`. '
                          'Some features might not work. This will be an error in future releases.', RuntimeWarning)
            current_context = ModelNamespace('default')
        next_uid = uid(current_context._simple_name())
        return current_context._simple_name() + '_' + str(next_uid)

def get_current_context(key: str) -> Any:
    return ContextStack.top(key)

# map variables to prefix in the state dict
# e.g., {'upsample': 'mynet.module.deconv2.upsample_layer'}
STATE_DICT_PY_MAPPING = '_mapping_'

# map variables to `prefix`.`value` in the state dict
# e.g., {'upsample': 'choice3.upsample_layer'},
# which actually means {'upsample': 'mynet.module.choice3.upsample_layer'},
# and 'upsample' is also in `mynet.module`.
STATE_DICT_PY_MAPPING_PARTIAL = '_mapping_partial_'

@contextmanager
def original_state_dict_hooks(model: Any):
    """
    Use this patch if you want to save/load state dict in the original state dict hierarchy.

    For example, when you already have a state dict for the base model / search space (which often
    happens when you have trained a supernet with one-shot strategies), the state dict isn't organized
    in the same way as when a sub-model is sampled from the search space. This patch will help
    the modules in the sub-model find the corresponding module in the base model.

    The code looks like,

    .. code-block:: python

        with original_state_dict_hooks(model):
            model.load_state_dict(state_dict_from_supernet, strict=False)  # supernet has extra keys

    Or vice-versa,

    .. code-block:: python

        with original_state_dict_hooks(model):
            supernet_style_state_dict = model.state_dict()
    """
    import torch.utils.hooks
    import torch.nn as nn
    assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'

    # the following are written for pytorch only

    # first get the full mapping
    full_mapping = {}

    def full_mapping_in_module(src_prefix, tar_prefix, module):
        if hasattr(module, STATE_DICT_PY_MAPPING):
            # only values are complete
            local_map = getattr(module, STATE_DICT_PY_MAPPING)
        elif hasattr(module, STATE_DICT_PY_MAPPING_PARTIAL):
            # keys and values are both incomplete
            local_map = getattr(module, STATE_DICT_PY_MAPPING_PARTIAL)
            local_map = {k: tar_prefix + v for k, v in local_map.items()}
        else:
            # no mapping
            local_map = {}

        if '__self__' in local_map:
            # special case, overwrite prefix
            tar_prefix = local_map['__self__'] + '.'

        for key, value in local_map.items():
            if key != '' and key not in module._modules:
                # not a sub-module, probably a parameter
                full_mapping[src_prefix + key] = value

        if src_prefix != tar_prefix:
            # To deal with leaf nodes.
            for name, value in itertools.chain(module._parameters.items(), module._buffers.items()):  # direct children
                if value is None or name in module._non_persistent_buffers_set:
                    # it won't appear in state dict
                    continue
                if (src_prefix + name) not in full_mapping:
                    full_mapping[src_prefix + name] = tar_prefix + name

        for name, child in module.named_children():
            # sub-modules
            full_mapping_in_module(
                src_prefix + name + '.',
                local_map.get(name, tar_prefix + name) + '.',  # if mapping doesn't exist, respect the prefix
                child
            )

    full_mapping_in_module('', '', model)

    def load_state_dict_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        reverse_mapping = defaultdict(list)
        for src, tar in full_mapping.items():
            reverse_mapping[tar].append(src)

        transf_state_dict = {}
        for src, tar_keys in reverse_mapping.items():
            if src in state_dict:
                value = state_dict.pop(src)
                for tar in tar_keys:
                    transf_state_dict[tar] = value
            else:
                missing_keys.append(src)
        state_dict.update(transf_state_dict)

    def state_dict_hook(module, destination, prefix, local_metadata):
        result = {}
        for src, tar in full_mapping.items():
            if src in destination:
                result[tar] = destination.pop(src)
            else:
                raise KeyError(f'"{src}" not in state dict, but found in mapping.')
        destination.update(result)

    hooks: List[torch.utils.hooks.RemovableHandle] = []

    try:
        hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
        hooks.append(model._register_state_dict_hook(state_dict_hook))
        yield
    finally:
        for hook in hooks:
            hook.remove()

from nni.nas.utils.misc import *
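A small sketch (not part of the diff) of the context utilities above, assuming they remain importable from nni.retiarii.utils in this revision. ContextStack keeps a per-key stack of values, and ModelNamespace.next_label() produces chained, auto-numbered labels inside the active namespace:

from nni.retiarii.utils import ContextStack, ModelNamespace, get_current_context

with ContextStack('fixed', {'layer1': 3}):
    # anything running inside the `with` block can look the value up by key
    assert get_current_context('fixed') == {'layer1': 3}

with ModelNamespace():                               # default key is 'model'
    print(ModelNamespace.next_label('model'))        # model_1 (counter reset on enter)
    print(ModelNamespace.next_label('model'))        # model_2
    with ModelNamespace():                           # nested namespace chains the numbering
        print(ModelNamespace.next_label('model'))    # model_3_1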
pyrightconfig.json
@@ -10,8 +10,8 @@
     "nni/common/device.py",
     "nni/common/graph_utils.py",
     "nni/compression",
-    "nni/nas/tensorflow",
-    "nni/nas/pytorch",
+    "nni/nas/execution/pytorch/cgo",
+    "nni/nas/evaluator/pytorch/cgo",
     "nni/retiarii/execution/cgo_engine.py",
     "nni/retiarii/execution/logical_optimizer",
     "nni/retiarii/evaluator/pytorch/cgo",
test/algo/nas/test_cgo_engine.py
@@ -32,6 +32,8 @@ try:
    from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
    import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer
    import nni.retiarii.integration_api
    module_import_failed = False
except ImportError:
    module_import_failed = True
test/algo/nas/test_space_hub.py
@@ -14,7 +14,7 @@ import nni.runtime.platform.test
 import nni.retiarii.evaluator.pytorch.lightning as pl
 import nni.retiarii.hub.pytorch as searchspace
 from nni.retiarii import fixed_arch
-from nni.retiarii.execution.utils import _unpack_if_only_one
+from nni.retiarii.execution.utils import unpack_if_only_one
 from nni.retiarii.mutator import InvalidMutation, Sampler
 from nni.retiarii.nn.pytorch.mutator import extract_mutation_from_pt_module

@@ -58,7 +58,7 @@ def _test_searchspace_on_dataset(searchspace, dataset='cifar10', arch=None):
     if arch is None:
         model = try_mutation_until_success(model, mutators, 10)
-        arch = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
+        arch = {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model.history}
     print('Selected model:', arch)
     with fixed_arch(arch):
test/algo/nas/test_strategy.py
@@ -56,7 +56,10 @@ class MockExecutionEngine(AbstractExecutionEngine):
 def _reset_execution_engine(engine=None):
-    nni.retiarii.execution.api._execution_engine = engine
+    # Use the new NAS reset
+    # nni.retiarii.execution.api._execution_engine = engine
+    import nni.nas.execution.api
+    nni.nas.execution.api._execution_engine = engine


 class Net(nn.Module):
test/ut/nas/debug_mnist_pytorch.py
@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-import nni.retiarii.nn.pytorch
+import nni.nas.nn.pytorch
 import torch
test/ut/nas/test_engine.py
@@ -4,6 +4,7 @@ import unittest
from pathlib import Path

import nni.retiarii
import nni.retiarii.integration_api
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution import set_execution_engine
test/ut/nas/test_mutator.py
 import json
 from pathlib import Path
 import sys
+from nni.common.framework import get_default_framework, set_default_framework
 from nni.retiarii import *
-# FIXME
-import nni.retiarii.debug_configs
-original_framework = nni.retiarii.debug_configs.framework
+original_framework = get_default_framework()

 max_pool = Operation.new('MaxPool2D', {'pool_size': 2})
 avg_pool = Operation.new('AveragePooling2D', {'pool_size': 2})

@@ -14,11 +12,11 @@ global_pool = Operation.new('GlobalAveragePooling2D')
 def setup_module(module):
-    nni.retiarii.debug_configs.framework = 'tensorflow'
+    set_default_framework('tensorflow')


 def teardown_module(module):
-    nni.retiarii.debug_configs.framework = original_framework
+    set_default_framework(original_framework)


 class DebugSampler(Sampler):
test/ut/nas/test_nn.py
@@ -15,7 +15,7 @@ from nni.retiarii import InvalidMutation, Sampler, basic_unit
 from nni.retiarii.converter import convert_to_graph
 from nni.retiarii.codegen import model_to_pytorch_script
 from nni.retiarii.evaluator import FunctionalEvaluator
-from nni.retiarii.execution.utils import _unpack_if_only_one
+from nni.retiarii.execution.utils import unpack_if_only_one
 from nni.retiarii.experiment.pytorch import preprocess_model
 from nni.retiarii.graph import Model
 from nni.retiarii.nn.pytorch.api import ValueChoice

@@ -827,7 +827,7 @@ class Python(GraphIR):
     graph_engine = False

     def _get_converted_pytorch_model(self, model_ir):
-        mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history}
+        mutation = {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model_ir.history}
         with ContextStack('fixed', mutation):
             model = model_ir.python_class(**model_ir.python_init_params)
             return model