OpenDAS / nni / Commits

Commit 5b7dac5c (unverified)
Authored Mar 02, 2022 by Yuge Zhang; committed by GitHub on Mar 02, 2022
Parent: c13392ab

Wrap one-shot algorithms as strategies (#4571)

Showing 17 changed files with 562 additions and 205 deletions (+562 / -205)
dependencies/develop.txt                         (+1, -0)
docs/source/NAS/ApiReference.rst                 (+16, -28)
docs/source/conf.py                              (+4, -0)
docs/source/refs.bib                             (+29, -0)
docs/static/css/material_custom.css              (+5, -0)
nni/retiarii/experiment/pytorch.py               (+125, -22)
nni/retiarii/graph.py                            (+3, -0)
nni/retiarii/nn/pytorch/mutator.py               (+14, -0)
nni/retiarii/oneshot/pytorch/__init__.py         (+2, -2)
nni/retiarii/oneshot/pytorch/base_lightning.py   (+46, -21)
nni/retiarii/oneshot/pytorch/differentiable.py   (+93, -59)
nni/retiarii/oneshot/pytorch/sampling.py         (+47, -38)
nni/retiarii/oneshot/pytorch/strategy.py         (+129, -0)
nni/retiarii/strategy/__init__.py                (+1, -0)
nni/retiarii/strategy/base.py                    (+4, -1)
nni/retiarii/strategy/oneshot.py                 (+22, -0)
test/ut/retiarii/test_oneshot.py                 (+21, -34)
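At a glance, the commit replaces the trainer-based one-shot API (DartsTrainer and friends) with strategy classes that plug into RetiariiExperiment. A minimal sketch of the new usage, assembled from the examples added to the RetiariiExperiment docstring below (`Net`, `train_loader` and `valid_loader` are assumed user-defined):

```python
import nni.retiarii.evaluator.pytorch.lightning as pl
from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment

base_model = Net()                   # assumed: an @model_wrapper-decorated module
search_strategy = strategy.DARTS()   # a one-shot algorithm, now exposed as a strategy
evaluator = pl.Classification(train_dataloader=train_loader,
                              val_dataloaders=valid_loader)

exp = RetiariiExperiment(base_model, evaluator, [], search_strategy)
exp_config = RetiariiExeConfig()
exp_config.execution_engine = 'oneshot'  # required for one-shot strategies
exp.run(exp_config)
```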
dependencies/develop.txt:

```diff
@@ -12,4 +12,5 @@ rstcheck
 sphinx
 sphinx-argparse-nni >= 0.4.0
 sphinx-gallery
+sphinxcontrib-bibtex
 git+https://github.com/bashtage/sphinx-material.git
```
docs/source/NAS/ApiReference.rst:

```diff
@@ -60,38 +60,12 @@ Evaluators
 ..  autoclass:: nni.retiarii.evaluator.pytorch.lightning.Regression
     :members:

-Oneshot Trainers
-----------------
-
-..  autoclass:: nni.retiarii.oneshot.pytorch.DartsTrainer
-    :members:
-
-..  autoclass:: nni.retiarii.oneshot.pytorch.EnasTrainer
-    :members:
-
-..  autoclass:: nni.retiarii.oneshot.pytorch.ProxylessTrainer
-    :members:
-
-..  autoclass:: nni.retiarii.oneshot.pytorch.SinglePathTrainer
-    :members:
-
 Exploration Strategies
 ----------------------

-..  autoclass:: nni.retiarii.strategy.Random
-    :members:
-
-..  autoclass:: nni.retiarii.strategy.GridSearch
-    :members:
-
-..  autoclass:: nni.retiarii.strategy.RegularizedEvolution
-    :members:
-
-..  autoclass:: nni.retiarii.strategy.TPEStrategy
-    :members:
-
-..  autoclass:: nni.retiarii.strategy.PolicyBasedRL
+..  automodule:: nni.retiarii.strategy
     :members:
+    :imported-members:

 Retiarii Experiments
 --------------------
@@ -111,6 +85,17 @@ CGO Execution
 ..  autofunction:: nni.retiarii.evaluator.pytorch.cgo.evaluator.Regression

+One-shot Implementation
+-----------------------
+
+..  automodule:: nni.retiarii.oneshot
+    :members:
+    :imported-members:
+
+..  automodule:: nni.retiarii.oneshot.pytorch
+    :members:
+    :imported-members:
+
 Utilities
 ---------
@@ -120,4 +105,7 @@ Utilities

 ..  autofunction:: nni.retiarii.fixed_arch
+
+Citations
+---------
+
+..  bibliography::
```
docs/source/conf.py:

```diff
@@ -49,6 +49,7 @@ extensions = [
     'sphinx.ext.napoleon',
     'sphinx.ext.viewcode',
     'sphinx.ext.intersphinx',
+    'sphinxcontrib.bibtex',
     # 'nbsphinx', # nbsphinx has conflicts with sphinx-gallery.
     'sphinx.ext.extlinks',
     'IPython.sphinxext.ipython_console_highlighting',
@@ -62,6 +63,9 @@ extensions = [
 # Add mock modules
 autodoc_mock_imports = ['apex', 'nni_node', 'tensorrt', 'pycuda', 'nn_meter']

+# Bibliography files
+bibtex_bibfiles = ['refs.bib']
+
 # Sphinx gallery examples
 sphinx_gallery_conf = {
     'examples_dirs': '../../examples/tutorials',   # path to your example scripts
```
docs/source/refs.bib (new file, 0 → 100644):

```bibtex
@inproceedings{liu2018darts,
  title     = {DARTS: Differentiable Architecture Search},
  author    = {Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
  booktitle = {International Conference on Learning Representations},
  year      = {2018}
}

@inproceedings{cai2018proxylessnas,
  title     = {ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware},
  author    = {Cai, Han and Zhu, Ligeng and Han, Song},
  booktitle = {International Conference on Learning Representations},
  year      = {2018}
}

@inproceedings{xie2018snas,
  title     = {SNAS: stochastic neural architecture search},
  author    = {Xie, Sirui and Zheng, Hehui and Liu, Chunxiao and Lin, Liang},
  booktitle = {International Conference on Learning Representations},
  year      = {2018}
}

@inproceedings{pham2018efficient,
  title        = {Efficient neural architecture search via parameters sharing},
  author       = {Pham, Hieu and Guan, Melody and Zoph, Barret and Le, Quoc and Dean, Jeff},
  booktitle    = {International Conference on Machine Learning},
  pages        = {4095--4104},
  year         = {2018},
  organization = {PMLR}
}
```
docs/static/css/material_custom.css:

```diff
@@ -47,3 +47,8 @@ nav.md-tabs .md-tabs__item:not(:last-child) .md-tabs__link:after {
 .md-nav span.caption {
     margin-top: 1.25em;
 }
+
+/* citation style */
+.citation dt {
+    padding-right: 1em;
+}
```
nni/retiarii/experiment/pytorch.py:

```diff
@@ -34,7 +34,9 @@ from ..execution.utils import get_mutation_dict
 from ..graph import Evaluator
 from ..integration import RetiariiAdvisor
 from ..mutator import Mutator
-from ..nn.pytorch.mutator import extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations
+from ..nn.pytorch.mutator import (
+    extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations, process_oneshot_mutations
+)
 from ..oneshot.interface import BaseOneShotTrainer
 from ..serializer import is_model_wrapped
 from ..strategy import BaseStrategy
@@ -86,7 +88,7 @@ class RetiariiExeConfig(ConfigBase):
         if key == 'trial_code_directory' and not (str(value) == '.' or os.path.isabs(value)):
             raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
         if key == 'execution_engine':
-            assert value in ['base', 'py', 'cgo', 'benchmark'], f'The specified execution engine "{value}" is not supported.'
+            assert value in ['base', 'py', 'cgo', 'benchmark', 'oneshot'], f'The specified execution engine "{value}" is not supported.'
             self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
         self.__dict__[key] = value
@@ -115,9 +117,11 @@ _validation_rules = {
 }

-def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_input=None):
+def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):
     # TODO: this logic might need to be refactored into execution engine
-    if full_ir:
+    if oneshot:
+        base_model_ir, mutators = process_oneshot_mutations(base_model, evaluator)
+    elif full_ir:
         try:
             script_module = torch.jit.script(base_model)
         except Exception as e:
@@ -134,7 +138,7 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_
             mutators = process_inline_mutation(base_model_ir)
     else:
         base_model_ir, mutators = extract_mutation_from_pt_module(base_model)
-    base_model_ir.evaluator = trainer
+    base_model_ir.evaluator = evaluator

     if mutators is not None and applied_mutators:
         raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
@@ -144,7 +148,7 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_
     return base_model_ir, applied_mutators

-def debug_mutated_model(base_model, trainer, applied_mutators):
+def debug_mutated_model(base_model, evaluator, applied_mutators):
     """
     Locally run only one trial without launching an experiment for debug purpose, then exit.
     For example, it can be used to quickly check shape mismatch.
@@ -152,16 +156,18 @@ def debug_mutated_model(base_model, trainer, applied_mutators):
     Specifically, it applies mutators (default to choosing the first candidate for the choices)
     to generate a new model, then runs this model locally.
     The model will be parsed with the graph execution engine.

     Parameters
     ----------
     base_model : nni.retiarii.nn.pytorch.nn.Module
         the base model
-    trainer : nni.retiarii.evaluator
+    evaluator : nni.retiarii.graph.Evaluator
         the training class of the generated models
     applied_mutators : list
         a list of mutators that will be applied on the base model for generating a new model
     """
-    base_model_ir, applied_mutators = preprocess_model(base_model, trainer, applied_mutators)
+    base_model_ir, applied_mutators = preprocess_model(base_model, evaluator, applied_mutators)

     from ..strategy import _LocalDebugStrategy
     strategy = _LocalDebugStrategy()
     strategy.run(base_model_ir, applied_mutators)
@@ -169,21 +175,99 @@ def debug_mutated_model(base_model, trainer, applied_mutators):

 class RetiariiExperiment(Experiment):
-    def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotTrainer],
-                 applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None):
+    """
+    The entry for a NAS experiment.
+    Users can use this class to start/stop or inspect an experiment, like exporting the results.
+
+    Experiment is a sub-class of :class:`nni.experiment.Experiment`, so there are many similarities,
+    such as the configurable training service for running the experiment distributedly on remote servers.
+    But unlike :class:`nni.experiment.Experiment`, RetiariiExperiment doesn't support configuring:
+
+    - ``trial_code_directory``, which can only be the current working directory.
+    - ``search_space``, which is auto-generated in NAS.
+    - ``trial_command``, which must be ``python -m nni.retiarii.trial_entry`` to launch the modulized trial code.
+
+    RetiariiExperiment also doesn't have a tuner/assessor/advisor, because they are implemented in the strategy.
+
+    Also, unlike :class:`nni.experiment.Experiment`, which is bound to a node server,
+    RetiariiExperiment starts a node server to schedule the trials only when the strategy is a multi-trial strategy.
+    When the strategy is one-shot, the step of launching the node server is omitted, and the experiment runs locally by default.
+
+    Configurations of experiments, such as execution engine and number of GPUs allocated,
+    should be put into a :class:`RetiariiExeConfig` and used as an argument of :meth:`RetiariiExperiment.run`.
+
+    Parameters
+    ----------
+    base_model : nn.Module
+        The model defining the search space / base skeleton without mutation.
+        It should be wrapped by the decorator ``nni.retiarii.model_wrapper``.
+    evaluator : nni.retiarii.Evaluator, default = None
+        Evaluator for the experiment.
+        If you are using a one-shot trainer, it should be placed here, although this usage is deprecated.
+    applied_mutators : list of nni.retiarii.Mutator, default = None
+        Mutators to mutate the base model. If none, mutators are skipped.
+        Note that when ``base_model`` uses inline mutations (e.g., LayerChoice), ``applied_mutators`` must be empty / none.
+    strategy : nni.retiarii.strategy.BaseStrategy, default = None
+        Exploration strategy. Can be multi-trial or one-shot.
+    trainer : BaseOneShotTrainer
+        Kept for compatibility purposes; deprecated in favor of ``evaluator``.
+
+    Examples
+    --------
+    Multi-trial NAS:
+
+    >>> base_model = Net()
+    >>> search_strategy = strategy.Random()
+    >>> model_evaluator = FunctionalEvaluator(evaluate_model)
+    >>> exp = RetiariiExperiment(base_model, model_evaluator, [], search_strategy)
+    >>> exp_config = RetiariiExeConfig('local')
+    >>> exp_config.trial_concurrency = 2
+    >>> exp_config.max_trial_number = 20
+    >>> exp_config.training_service.use_active_gpu = False
+    >>> exp.run(exp_config, 8081)
+
+    One-shot NAS:
+
+    >>> base_model = Net()
+    >>> search_strategy = strategy.DARTS()
+    >>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
+    >>> exp = RetiariiExperiment(base_model, evaluator, [], search_strategy)
+    >>> exp_config = RetiariiExeConfig()
+    >>> exp_config.execution_engine = 'oneshot'  # must be set for a one-shot strategy
+    >>> exp.run(exp_config)
+
+    Export top models:
+
+    >>> for model_dict in exp.export_top_models(formatter='dict'):
+    ...     print(model_dict)
+    >>> with nni.retiarii.fixed_arch(model_dict):
+    ...     final_model = Net()
+    """
+
+    def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = None,
+                 applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None,
+                 trainer: BaseOneShotTrainer = None):
+        if trainer is not None:
+            warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
+                          'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
+            evaluator = trainer
+
+        if evaluator is None:
+            raise ValueError('Evaluator should not be none.')
+
         # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
         self.config: RetiariiExeConfig = None
         self.port: Optional[int] = None

         self.base_model = base_model
-        self.trainer = trainer
+        self.evaluator: Evaluator = evaluator
         self.applied_mutators = applied_mutators
         self.strategy = strategy

-        self._dispatcher = RetiariiAdvisor()
-        self._dispatcher_thread: Optional[Thread] = None
-        self._proc: Optional[Popen] = None
-        self._pipe: Optional[Pipe] = None
+        # FIXME: this is only a workaround
+        from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
+        if not isinstance(strategy, OneShotStrategy):
+            self._dispatcher = RetiariiAdvisor()
+            self._dispatcher_thread: Optional[Thread] = None
+            self._proc: Optional[Popen] = None
+            self._pipe: Optional[Pipe] = None
+
+        self.url_prefix = None
@@ -196,11 +280,11 @@ class RetiariiExperiment(Experiment):
     def _start_strategy(self):
         base_model_ir, self.applied_mutators = preprocess_model(
-            self.base_model, self.trainer, self.applied_mutators,
+            self.base_model, self.evaluator, self.applied_mutators,
             full_ir=self.config.execution_engine not in ['py', 'benchmark'],
             dummy_input=self.config.dummy_input)
-        self.applied_mutators += process_evaluator_mutations(self.trainer, self.applied_mutators)
+        self.applied_mutators += process_evaluator_mutations(self.evaluator, self.applied_mutators)

         _logger.info('Start strategy...')
         search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
@@ -302,8 +386,23 @@ class RetiariiExperiment(Experiment):
         Run the experiment.
         This function will block until the experiment finishes or errors.
         """
-        if isinstance(self.trainer, BaseOneShotTrainer):
-            self.trainer.fit()
+        if isinstance(self.evaluator, BaseOneShotTrainer):
+            # TODO: will throw a deprecation warning soon
+            # warnings.warn('You are using the old implementation of one-shot algos based on one-shot trainer. '
+            #               'We will try to convert this trainer to our new implementation to run the algorithm. '
+            #               'In case you want to stick to the old implementation, '
+            #               'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
+            self.evaluator.fit()
+            return
+
+        if config is None:
+            warnings.warn('`config = None` is deprecated and will be unsupported in the future. If you are running a one-shot experiment, '
+                          'please consider creating a config and setting the execution engine to `oneshot`.', DeprecationWarning)
+            config = RetiariiExeConfig()
+            config.execution_engine = 'oneshot'
+
+        if config.execution_engine == 'oneshot':
+            base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.evaluator, self.applied_mutators, oneshot=True)
+            self.strategy.run(base_model_ir, self.applied_mutators)
         else:
-            assert config is not None, 'You are using classic search mode, config cannot be None!'
             self.config = config
@@ -388,10 +487,14 @@ class RetiariiExperiment(Experiment):
         """
         if formatter == 'code':
             assert self.config.execution_engine != 'py', 'You should use `dict` formatter when using Python execution engine.'
-        if isinstance(self.trainer, BaseOneShotTrainer):
+        if isinstance(self.evaluator, BaseOneShotTrainer):
             assert top_k == 1, 'Only support top_k is 1 for now.'
-            return self.trainer.export()
+            return self.evaluator.export()
         else:
+            try:
+                # this currently works for one-shot algorithms
+                return self.strategy.export_top_models(top_k=top_k)
+            except NotImplementedError:
+                # when strategy hasn't implemented its own export logic
                 all_models = filter(lambda m: m.metric is not None, list_models())
                 assert optimize_mode in ['maximize', 'minimize']
                 all_models = sorted(all_models, key=lambda m: m.metric, reverse=optimize_mode == 'maximize')
```
nni/retiarii/graph.py:

```diff
@@ -84,6 +84,8 @@ class Model:
     Attributes
     ----------
+    python_object
+        Python object of base model. It will be none when the base model is not available.
     python_class
         Python class that base model is converted from.
     python_init_params
@@ -110,6 +112,7 @@ class Model:
     def __init__(self, _internal=False):
         assert _internal, '`Model()` is private, use `model.fork()` instead'
         self.model_id: int = uid('model')
+        self.python_object: Optional[Any] = None  # type is uncertain because it could differ between DL frameworks
         self.python_class: Optional[Type] = None
         self.python_init_params: Optional[Dict[str, Any]] = None
```
nni/retiarii/nn/pytorch/mutator.py:

```diff
@@ -409,6 +409,20 @@ def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mu
     return mutators

+# The following are written for one-shot mode.
+# They shouldn't technically belong here, but all other engines are written here.
+# Let's refactor later.
+
+def process_oneshot_mutations(base_model: nn.Module, evaluator: Evaluator):
+    # It's not intuitive, at all (actually very hacky), to wrap a `base_model` and `evaluator` into a graph.Model.
+    # But unfortunately, this is the required interface of strategy.
+    model = Model(_internal=True)
+    model.python_object = base_model
+    # no need to set evaluator here because it will be set after this method is called
+    return model, []
+
+
 # utility functions
```
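For orientation, a hedged sketch of what this wrapping produces, based only on the calls visible in this diff (the stand-in values are illustrative, not from the commit):

```python
import torch.nn as nn
from nni.retiarii.graph import Model
from nni.retiarii.nn.pytorch.mutator import process_oneshot_mutations

base_model = nn.Linear(4, 2)   # stand-in for the user's search-space module
evaluator = None               # stand-in; a Lightning evaluator in practice

model_ir, mutators = process_oneshot_mutations(base_model, evaluator)
assert isinstance(model_ir, Model)
assert model_ir.python_object is base_model  # the raw module rides along in the IR
assert mutators == []                        # one-shot finds its "mutators" on its own

model_ir.evaluator = evaluator  # preprocess_model attaches this right after the call
# strategy.run(model_ir, mutators) later pulls python_object back out
# and hands it to the one-shot lightning module.
```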
nni/retiarii/oneshot/pytorch/__init__.py:

```diff
@@ -5,6 +5,6 @@ from .darts import DartsTrainer
 from .enas import EnasTrainer
 from .proxyless import ProxylessTrainer
 from .random import SinglePathTrainer, RandomTrainer
-from .differentiable import DartsModule, ProxylessModule, SNASModule
-from .sampling import EnasModule, RandomSampleModule
+from .differentiable import DartsModule, ProxylessModule, SnasModule
+from .sampling import EnasModule, RandomSamplingModule
 from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
```
nni/retiarii/oneshot/pytorch/base_lightning.py:

```diff
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from typing import Dict, Type, Callable, List, Optional
+
 import pytorch_lightning as pl
 import torch.optim as optim
 import torch.nn as nn
 from torch.optim.lr_scheduler import _LRScheduler

+ReplaceDictType = Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]]
+

-def _replace_module_with_type(root_module, replace_dict, modules):
+def _replace_module_with_type(root_module: nn.Module, replace_dict: ReplaceDictType, modules: List[nn.Module]):
     """
     Replace xxxChoice in user's model with NAS modules.
@@ -45,31 +49,50 @@ def _replace_module_with_type(root_module, replace_dict, modules):

 class BaseOneShotLightningModule(pl.LightningModule):
-    """
-    The base class for all one-shot NAS modules. Essential functions such as preprocessing user's model, redirecting lightning
-    hooks for user's model, configuring optimizers and exporting NAS result are implemented in this class.
-
-    Attributes
-    ----------
-    nas_modules : List[nn.Module]
-        The replace result of a specific NAS method.
-        xxxChoice will be replaced with some other modules with respect to the NAS method.
-
-    Parameters
-    ----------
-    base_model : pl.LightningModule
-        The evaluator in ``nni.retiarii.evaluator.lightning``. The user-defined model is wrapped by base_model,
-        and base_model will be wrapped by this model.
-    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
-        The custom xxxChoice replace method. Keys should be xxxChoice types and values should return an ``nn.Module``.
-        This custom replace dict will override the default replace dict of each NAS method.
-    """
+    _custom_replace_dict_note = """custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
+        The custom xxxChoice replace method. Keys should be ``xxxChoice`` types.
+        Values should be callables accepting an ``nn.Module`` and returning an ``nn.Module``.
+        This custom replace dict will override the default replace dict of each NAS method.
+    """
+
+    _inner_module_note = """inner_module : pytorch_lightning.LightningModule
+        It's a `LightningModule <https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html>`__
+        that defines computations, train/val loops, optimizers in a single class.
+        When used in NNI, the ``inner_module`` is the combination of instances of evaluator + base model
+        (to be precise, a base model wrapped with LightningModule in evaluator).
+    """
+
+    __doc__ = """
+    The base class for all one-shot NAS modules.
+
+    In NNI, we try to separate the "search" part and the "training" part in one-shot NAS.
+    The "training" part is defined with the evaluator interface (it has to be the lightning evaluator interface to work with one-shot).
+    Since the lightning evaluator has already broken the training down into minimal building blocks,
+    we can re-assemble them after combining them with the "search" part of a particular algorithm.
+
+    After the re-assembling, this module has defined all the search + training. The experiment can use a lightning trainer
+    (which is another part of the evaluator) to train this module, so as to complete the search process.
+
+    Essential functions such as preprocessing user's model, redirecting lightning hooks for user's model,
+    configuring optimizers and exporting NAS result are implemented in this class.
+
+    Attributes
+    ----------
+    nas_modules : List[nn.Module]
+        The replace result of a specific NAS method.
+        xxxChoice will be replaced with some other modules with respect to the NAS method.
+
+    Parameters
+    ----------
+    """ + _inner_module_note + _custom_replace_dict_note

     automatic_optimization = False

-    def __init__(self, base_model, custom_replace_dict=None):
+    def __init__(self, inner_module: pl.LightningModule, custom_replace_dict: Optional[ReplaceDictType] = None):
         super().__init__()
-        assert isinstance(base_model, pl.LightningModule)
-        self.model = base_model
+        assert isinstance(inner_module, pl.LightningModule)
+        self.model = inner_module

         # replace xxxChoice with respect to NAS alg
         # replaced modules are stored in self.nas_modules
@@ -85,16 +108,18 @@ class BaseOneShotLightningModule(pl.LightningModule):
         return self.model(x)

     def training_step(self, batch, batch_idx):
-        # You can use self.architecture_optimizers or self.user_optimizers to get optimizers in
-        # your own training step.
+        """This is the implementation of what happens in the training loops of one-shot algos.
+        It usually calls ``self.model.training_step``, which implements the real training recipe of the user's model.
+        """
         return self.model.training_step(batch, batch_idx)

     def configure_optimizers(self):
         """
         Combine architecture optimizers and user's model optimizers.
         You can overwrite configure_architecture_optimizers if architecture optimizers are needed in your NAS algorithm.
-        By now ``self.model`` is currently a :class:`nni.retiarii.evaluator.pytorch.lightning._SupervisedLearningModule`
-        and it only returns 1 optimizer. But for extendibility, codes for other return value types are also implemented.
+        For now, ``self.model`` is tested against :class:`nni.retiarii.evaluator.pytorch.lightning._SupervisedLearningModule`,
+        and it only returns 1 optimizer.
+        But for extensibility, code for other return value types is also implemented.
         """
         # pylint: disable=assignment-from-none
         arc_optimizers = self.configure_architecture_optimizers()
@@ -178,8 +203,8 @@ class BaseOneShotLightningModule(pl.LightningModule):
     @property
     def default_replace_dict(self):
         """
-        Default xxxChoice replace dict. This is called in ``__init__`` to get the default replace functions for your NAS algorithm.
-        Note that your default replace functions may be overridden by user-defined custom_replace_dict.
+        Default ``xxxChoice`` replace dict. This is called in ``__init__`` to get the default replace functions for your NAS algorithm.
+        Note that your default replace functions may be overridden by user-defined ``custom_replace_dict``.

         Returns
         ----------
```
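The ReplaceDictType contract is easiest to see with a concrete replacement. A minimal sketch, assuming LayerChoice iterates over its candidate modules; MyLayerChoice is illustrative only, not part of this diff:

```python
import torch.nn as nn
from nni.retiarii.nn.pytorch import LayerChoice

class MyLayerChoice(nn.Module):
    """Hypothetical replacement module wrapping a LayerChoice's candidates."""
    def __init__(self, layer_choice: LayerChoice):
        super().__init__()
        # assumption: iterating a LayerChoice yields its candidate modules
        self.candidates = nn.ModuleList(list(layer_choice))

    def forward(self, x):
        # trivial policy for illustration: always run the first candidate
        return self.candidates[0](x)

# Keys are choice types; values are callables from the original module to its
# replacement. This overrides the algorithm's default_replace_dict entry.
custom_replace_dict = {LayerChoice: lambda module: MyLayerChoice(module)}
```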
nni/retiarii/oneshot/pytorch/differentiable.py:

```diff
@@ -2,12 +2,14 @@
 # Licensed under the MIT license.

 from collections import OrderedDict
+from typing import Optional

 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
-from .base_lightning import BaseOneShotLightningModule
+from .base_lightning import BaseOneShotLightningModule, ReplaceDictType


 class DartsLayerChoice(nn.Module):
@@ -63,17 +65,35 @@ class DartsInputChoice(nn.Module):

 class DartsModule(BaseOneShotLightningModule):
-    """
-    The DARTS module. Each iteration consists of 2 training phases. Phase 1 is the architecture step, in which model parameters
-    are frozen and the architecture parameters are trained. Phase 2 is the model step, in which architecture parameters are
-    frozen and model parameters are trained. See [darts] for details.
-
-    The DARTS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.
-
-    Reference
-    ----------
-    .. [darts] H. Liu, K. Simonyan, and Y. Yang, "DARTS: Differentiable Architecture Search," presented at the
-        International Conference on Learning Representations, Sep. 2018. Available: https://openreview.net/forum?id=S1eYHoC5FX
-    """
+    _darts_note = """
+    The DARTS :cite:p:`liu2018darts` algorithm is one of the most fundamental one-shot algorithms.
+
+    DARTS repeats iterations, where each iteration consists of 2 training phases.
+    Phase 1 is the architecture step, in which model parameters are frozen and the architecture parameters are trained.
+    Phase 2 is the model step, in which architecture parameters are frozen and model parameters are trained.
+
+    The current implementation is for DARTS in first order. Second order (unrolled) is not supported yet.
+
+    {{module_notes}}
+
+    Parameters
+    ----------
+    {{module_params}}
+    {base_params}
+    arc_learning_rate : float
+        Learning rate for architecture optimizer. Default: 3.0e-4
+    """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
+
+    __doc__ = _darts_note.format(
+        module_notes='The DARTS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.',
+        module_params=BaseOneShotLightningModule._inner_module_note,
+    )
+
+    def __init__(self, inner_module: pl.LightningModule,
+                 custom_replace_dict: Optional[ReplaceDictType] = None,
+                 arc_learning_rate: float = 3.0e-4):
+        super().__init__(inner_module, custom_replace_dict=custom_replace_dict)
+        self.arc_learning_rate = arc_learning_rate

     def training_step(self, batch, batch_idx):
         # grad manually
@@ -118,8 +138,8 @@ class DartsModule(BaseOneShotLightningModule):
     @property
     def default_replace_dict(self):
         return {
-            LayerChoice : DartsLayerChoice,
-            InputChoice : DartsInputChoice
+            LayerChoice: DartsLayerChoice,
+            InputChoice: DartsInputChoice
         }
@@ -132,7 +152,7 @@ class DartsModule(BaseOneShotLightningModule):
         else:
             ctrl_params[m.name] = m.alpha
         ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), 3.e-4, betas=(0.5, 0.999),
                                       weight_decay=1.0E-3)
         return ctrl_optim
@@ -279,28 +299,34 @@ class ProxylessInputChoice(nn.Module):

 class ProxylessModule(DartsModule):
-    """
-    The Proxyless Module. This is a DARTS-based method that resamples the architecture to reduce memory consumption.
-
-    The Proxyless Module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.
-
-    Reference
-    ----------
-    .. [proxyless] H. Cai, L. Zhu, and S. Han, "ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware,"
-        presented at the International Conference on Learning Representations, Sep. 2018.
-        Available: https://openreview.net/forum?id=HylVB3AqYm
-    """
+    _proxyless_note = """
+    Implementation of ProxylessNAS :cite:p:`cai2018proxylessnas`.
+    It's a DARTS-based method that resamples the architecture to reduce memory consumption.
+    Essentially, it samples one path on forward,
+    and implements its own backward to update the architecture parameters based on only one path.
+
+    {{module_notes}}
+
+    Parameters
+    ----------
+    {{module_params}}
+    {base_params}
+    arc_learning_rate : float
+        Learning rate for architecture optimizer. Default: 3.0e-4
+    """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
+
+    __doc__ = _proxyless_note.format(
+        module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.',
+        module_params=BaseOneShotLightningModule._inner_module_note,
+    )

     @property
     def default_replace_dict(self):
         return {
-            LayerChoice : ProxylessLayerChoice,
-            InputChoice : ProxylessInputChoice
+            LayerChoice: ProxylessLayerChoice,
+            InputChoice: ProxylessInputChoice
         }

     def configure_architecture_optimizers(self):
         ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], 3.e-4,
                                       weight_decay=0, betas=(0, 0.999), eps=1e-8)
         return ctrl_optim

     def _resample(self):
         for _, m in self.nas_modules:
             m.resample()
@@ -312,52 +338,60 @@ class ProxylessModule(DartsModule):

 class SNASLayerChoice(DartsLayerChoice):
     def forward(self, *args, **kwargs):
-        self.one_hot = F.gumbel_softmax(self.alpha, self.temp)
+        one_hot = F.gumbel_softmax(self.alpha, self.temp)
         op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
         alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
-        yhat = torch.sum(op_results * self.one_hot.view(*alpha_shape), 0)
+        yhat = torch.sum(op_results * one_hot.view(*alpha_shape), 0)
         return yhat


 class SNASInputChoice(DartsInputChoice):
     def forward(self, inputs):
-        self.one_hot = F.gumbel_softmax(self.alpha, self.temp)
+        one_hot = F.gumbel_softmax(self.alpha, self.temp)
         inputs = torch.stack(inputs)
         alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
-        yhat = torch.sum(inputs * self.one_hot.view(*alpha_shape), 0)
+        yhat = torch.sum(inputs * one_hot.view(*alpha_shape), 0)
         return yhat


-class SNASModule(DartsModule):
-    """
-    The SNAS Module. This is a DARTS-based method that uses gumbel-softmax to simulate one-hot distribution.
-
-    The SNAS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.
-
-    Parameters
-    ----------
-    base_model : pl.LightningModule
-        The evaluator in ``nni.retiarii.evaluator.lightning``. The user-defined model is wrapped by base_model,
-        and base_model will be wrapped by this model.
-    gumble_temperature : float
-        The initial temperature used in gumbel-softmax.
-    use_temp_anneal : bool
-        True: a linear annealing will be applied to gumble_temperature. False: run at a fixed temperature. See [snas] for details.
-    min_temp : float
-        The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
-    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
-        The custom xxxChoice replace method. Keys should be xxxChoice types and values should return an ``nn.Module``.
-        This custom replace dict will override the default replace dict of each NAS method.
-
-    Reference
-    ----------
-    .. [snas] S. Xie, H. Zheng, C. Liu, and L. Lin, "SNAS: stochastic neural architecture search," presented at the
-        International Conference on Learning Representations, Sep. 2018. Available: https://openreview.net/forum?id=rylqooRqK7
-    """
-
-    def __init__(self, base_model, gumble_temperature=1., use_temp_anneal=False, min_temp=.33, custom_replace_dict=None):
-        super().__init__(base_model, custom_replace_dict)
-        self.temp = gumble_temperature
-        self.init_temp = gumble_temperature
+class SnasModule(DartsModule):
+    _snas_note = """
+    Implementation of SNAS :cite:p:`xie2018snas`.
+    It's a DARTS-based method that uses gumbel-softmax to simulate one-hot distribution.
+
+    {{module_notes}}
+
+    Parameters
+    ----------
+    {{module_params}}
+    {base_params}
+    gumbel_temperature : float
+        The initial temperature used in gumbel-softmax.
+    use_temp_anneal : bool
+        If true, a linear annealing will be applied to ``gumbel_temperature``.
+        Otherwise, run at a fixed temperature. See :cite:t:`xie2018snas` for details.
+    min_temp : float
+        The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
+    arc_learning_rate : float
+        Learning rate for architecture optimizer. Default: 3.0e-4
+    """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
+
+    __doc__ = _snas_note.format(
+        module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.',
+        module_params=BaseOneShotLightningModule._inner_module_note,
+    )
+
+    def __init__(self, inner_module,
+                 custom_replace_dict: Optional[ReplaceDictType] = None,
+                 arc_learning_rate: float = 3.0e-4,
+                 gumbel_temperature: float = 1.,
+                 use_temp_anneal: bool = False,
+                 min_temp: float = .33):
+        super().__init__(inner_module, custom_replace_dict, arc_learning_rate=arc_learning_rate)
+        self.temp = gumbel_temperature
+        self.init_temp = gumbel_temperature
         self.use_temp_anneal = use_temp_anneal
         self.min_temp = min_temp
@@ -366,14 +400,14 @@ class SNASModule(DartsModule):
             self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
             self.temp = max(self.temp, self.min_temp)
-            for _, nas_module in self.nas_modules:
-                nas_module.temp = self.temp
+        for _, nas_module in self.nas_modules:
+            nas_module.temp = self.temp
         return self.model.on_epoch_start()

     @property
     def default_replace_dict(self):
         return {
-            LayerChoice : SNASLayerChoice,
-            InputChoice : SNASInputChoice
+            LayerChoice: SNASLayerChoice,
+            InputChoice: SNASInputChoice
         }
```
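To make the gumbel-softmax and temperature mechanics above concrete, a small self-contained sketch (not from the diff; made-up shapes, mirroring SNASLayerChoice and the linear annealing in on_epoch_start):

```python
import torch
import torch.nn.functional as F

def annealed_temp(epoch: int, max_epochs: int, init_temp: float = 1.0, min_temp: float = 0.33) -> float:
    # Linear schedule as in SnasModule, clamped at min_temp.
    temp = (1 - epoch / max_epochs) * (init_temp - min_temp) + min_temp
    return max(temp, min_temp)

alpha = torch.randn(4, requires_grad=True)         # architecture logits for 4 candidate ops
one_hot = F.gumbel_softmax(alpha, tau=annealed_temp(epoch=5, max_epochs=10))

op_results = torch.randn(4, 8)                     # pretend outputs of the 4 ops (batch of 8)
alpha_shape = [-1] + [1] * (op_results.dim() - 1)  # broadcast weights over output dims
yhat = torch.sum(op_results * one_hot.view(*alpha_shape), 0)
print(yhat.shape)  # torch.Size([8]); gradients flow back into alpha
```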
nni/retiarii/oneshot/pytorch/sampling.py:

```diff
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from typing import Dict, Any, Optional
+
 import random

 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 import torch.optim as optim

 from nni.retiarii.nn.pytorch.api import LayerChoice, InputChoice

 from .random import PathSamplingLayerChoice, PathSamplingInputChoice
-from .base_lightning import BaseOneShotLightningModule
+from .base_lightning import BaseOneShotLightningModule, ReplaceDictType
 from .enas import ReinforceController, ReinforceField


 class EnasModule(BaseOneShotLightningModule):
-    """
-    The ENAS module. There are 2 steps in an epoch. 1: training model parameters. 2: training the ENAS RL agent.
-    The agent will produce a sample of model architecture to get the best reward.
-
-    The ENASModule should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.
-
-    Parameters
-    ----------
-    base_model : pl.LightningModule
-        The evaluator in ``nni.retiarii.evaluator.lightning``. The user-defined model is wrapped by base_model,
-        and base_model will be wrapped by this model.
+    _enas_note = """
+    The implementation of ENAS :cite:p:`pham2018efficient`. There are 2 steps in an epoch.
+    Firstly, training model parameters.
+    Secondly, training the ENAS RL agent. The agent will produce a sample of model architecture to get the best reward.
+
+    {{module_notes}}
+
+    Parameters
+    ----------
+    {{module_params}}
+    {base_params}
     ctrl_kwargs : dict
         Optional kwargs that will be passed to :class:`ReinforceController`.
     entropy_weight : float
@@ -33,22 +37,25 @@ class EnasModule(BaseOneShotLightningModule):
         Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
     ctrl_steps_aggregate : int
         Number of steps that will be aggregated into one mini-batch for the RL controller.
-    grad_clip : float
-        Gradient clipping value.
-    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
-        The custom xxxChoice replace method. Keys should be xxxChoice types and values should return an ``nn.Module``.
-        This custom replace dict will override the default replace dict of each NAS method.
-
-    Reference
-    ----------
-    .. [enas] H. Pham, M. Guan, B. Zoph, Q. Le, and J. Dean, "Efficient Neural Architecture Search via Parameters Sharing,"
-        in Proceedings of the 35th International Conference on Machine Learning, Jul. 2018, pp. 4095-4104.
-        Available: https://proceedings.mlr.press/v80/pham18a.html
-    """
-
-    def __init__(self, base_model, ctrl_kwargs=None,
-                 entropy_weight=1e-4, skip_weight=.8, baseline_decay=.999,
-                 ctrl_steps_aggregate=20, grad_clip=0, custom_replace_dict=None):
-        super().__init__(base_model, custom_replace_dict)
+    ctrl_grad_clip : float
+        Gradient clipping value of the controller.
+    """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
+
+    __doc__ = _enas_note.format(
+        module_notes='``EnasModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
+        module_params=BaseOneShotLightningModule._inner_module_note,
+    )
+
+    def __init__(self, inner_module: pl.LightningModule,
+                 ctrl_kwargs: Dict[str, Any] = None,
+                 entropy_weight: float = 1e-4,
+                 skip_weight: float = .8,
+                 baseline_decay: float = .999,
+                 ctrl_steps_aggregate: int = 20,
+                 ctrl_grad_clip: float = 0,
+                 custom_replace_dict: Optional[ReplaceDictType] = None):
+        super().__init__(inner_module, custom_replace_dict)

         self.nas_fields = [ReinforceField(name, len(module),
                                           isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
@@ -60,7 +67,7 @@ class EnasModule(BaseOneShotLightningModule):
         self.baseline_decay = baseline_decay
         self.baseline = 0.
         self.ctrl_steps_aggregate = ctrl_steps_aggregate
-        self.grad_clip = grad_clip
+        self.ctrl_grad_clip = ctrl_grad_clip

     def configure_architecture_optimizers(self):
         return optim.Adam(self.controller.parameters(), lr=3.5e-4)
@@ -116,8 +123,8 @@ class EnasModule(BaseOneShotLightningModule):
                 self.manual_backward(rnn_step_loss)

         if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
-            if self.grad_clip > 0:
-                nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip)
+            if self.ctrl_grad_clip > 0:
+                nn.utils.clip_grad_norm_(self.controller.parameters(), self.ctrl_grad_clip)
             arc_opt.step()
             arc_opt.zero_grad()
@@ -135,20 +142,22 @@ class EnasModule(BaseOneShotLightningModule):
         return self.controller.resample()


-class RandomSampleModule(BaseOneShotLightningModule):
-    """
-    Random Sampling NAS Algorithm. In each epoch, model parameters are trained after a uniformly random sampling of each choice.
-    The training result is also a random sample of the search space.
-
-    Parameters
-    ----------
-    base_model : pl.LightningModule
-        The evaluator in ``nni.retiarii.evaluator.lightning``. The user-defined model is wrapped by base_model,
-        and base_model will be wrapped by this model.
-    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
-        The custom xxxChoice replace method. Keys should be xxxChoice types and values should return an ``nn.Module``.
-        This custom replace dict will override the default replace dict of each NAS method.
-    """
+class RandomSamplingModule(BaseOneShotLightningModule):
+    _random_note = """
+    Random Sampling NAS Algorithm.
+    In each epoch, model parameters are trained after a uniformly random sampling of each choice.
+    Notably, the exported result is **also a random sample** of the search space.
+
+    Parameters
+    ----------
+    {{module_params}}
+    {base_params}
+    """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
+
+    __doc__ = _random_note.format(
+        module_params=BaseOneShotLightningModule._inner_module_note,
+    )

     automatic_optimization = True

     def training_step(self, batch, batch_idx):
```
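The baseline formula quoted in the ENAS docstring is worth seeing in isolation; a quick sketch with made-up rewards (plain Python, not from the diff):

```python
def update_baseline(baseline: float, reward: float, baseline_decay: float = 0.999) -> float:
    # Exactly the docstring's rule:
    # new_baseline = baseline_decay * baseline_old + reward * (1 - baseline_decay)
    return baseline_decay * baseline + reward * (1 - baseline_decay)

baseline = 0.0
for reward in [0.52, 0.55, 0.61]:      # made-up controller rewards
    baseline = update_baseline(baseline, reward)
    advantage = reward - baseline      # REINFORCE scales the log-prob gradient by this
print(f'{baseline:.6f}')               # the baseline slowly tracks the reward level
```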
nni/retiarii/oneshot/pytorch/strategy.py (new file, 0 → 100644):

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Strategy integration of one-shot.

This file is put here simply because it relies on "pytorch".
For consistency, please consider importing strategies from ``nni.retiarii.strategy``,
e.g., ``nni.retiarii.strategy.DARTS`` (this requires pytorch to be installed, of course).
"""

import warnings
from typing import Any, List, Optional, Type, Union, Tuple

import torch.nn as nn
from torch.utils.data import DataLoader

from nni.retiarii.graph import Model
from nni.retiarii.strategy.base import BaseStrategy
from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule

from .base_lightning import BaseOneShotLightningModule
from .differentiable import DartsModule, ProxylessModule, SnasModule
from .sampling import EnasModule, RandomSamplingModule
from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader


class OneShotStrategy(BaseStrategy):
    """Wrap a one-shot lightning module as a one-shot strategy."""

    def __init__(self, oneshot_module: Type[BaseOneShotLightningModule], **kwargs):
        self.oneshot_module = oneshot_module
        self.oneshot_kwargs = kwargs

        self.model: Optional[BaseOneShotLightningModule] = None

    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader) \
            -> Union[DataLoader, Tuple[DataLoader, DataLoader]]:
        """
        A one-shot strategy typically requires a customized dataloader.

        If only a train dataloader is produced, return one dataloader.
        Otherwise, return the train dataloader and the valid loader as a tuple.
        """
        raise NotImplementedError()

    def run(self, base_model: Model, applied_mutators):
        # one-shot strategy doesn't use ``applied_mutators``;
        # it gets the "mutators" on its own
        _reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'

        py_model: nn.Module = base_model.python_object
        if not isinstance(py_model, nn.Module):
            raise TypeError('Model is not a nn.Module. ' + _reason)
        if applied_mutators:
            raise ValueError('Mutator is not empty. ' + _reason)
        if not isinstance(base_model.evaluator, Lightning):
            raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')

        evaluator_module: LightningModule = base_model.evaluator.module
        evaluator_module.set_model(py_model)
        self.model: BaseOneShotLightningModule = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)

        evaluator: Lightning = base_model.evaluator
        dataloader = self._get_dataloader(evaluator.train_dataloader, evaluator.val_dataloaders)
        if isinstance(dataloader, tuple):
            dataloader, val_loader = dataloader
            evaluator.trainer.fit(self.model, dataloader, val_loader)
        else:
            evaluator.trainer.fit(self.model, dataloader)

    def export_top_models(self, top_k: int = 1) -> List[Any]:
        if self.model is None:
            raise RuntimeError('One-shot strategy needs to be run before export.')
        if top_k != 1:
            warnings.warn('One-shot strategy currently only supports exporting the top-1 model.', RuntimeWarning)
        return [self.model.export()]


class DARTS(OneShotStrategy):
    __doc__ = DartsModule._darts_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(DartsModule, **kwargs)

    def _get_dataloader(self, train_dataloader, val_dataloaders):
        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)


class Proxyless(OneShotStrategy):
    __doc__ = ProxylessModule._proxyless_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(ProxylessModule, **kwargs)

    def _get_dataloader(self, train_dataloader, val_dataloaders):
        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)


class SNAS(OneShotStrategy):
    __doc__ = SnasModule._snas_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(SnasModule, **kwargs)

    def _get_dataloader(self, train_dataloader, val_dataloaders):
        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)


class ENAS(OneShotStrategy):
    __doc__ = EnasModule._enas_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(EnasModule, **kwargs)

    def _get_dataloader(self, train_dataloader, val_dataloaders):
        return ConcatenateTrainValDataLoader(train_dataloader, val_dataloaders)


class RandomOneShot(OneShotStrategy):
    __doc__ = RandomSamplingModule._random_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(RandomSamplingModule, **kwargs)

    def _get_dataloader(self, train_dataloader, val_dataloaders):
        return train_dataloader, val_dataloaders
```
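Since OneShotStrategy is a thin adapter, adding a new algorithm amounts to pairing a lightning module with a dataloader policy. A hedged sketch (MyModule and MyStrategy are hypothetical, not part of this commit):

```python
from nni.retiarii.oneshot.pytorch.base_lightning import BaseOneShotLightningModule
from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
from nni.retiarii.oneshot.pytorch.utils import InterleavedTrainValDataLoader

class MyModule(BaseOneShotLightningModule):
    """Hypothetical one-shot module; a real one would override replace dicts, etc."""
    automatic_optimization = True  # rely on the user's optimizers only

class MyStrategy(OneShotStrategy):
    def __init__(self, **kwargs):
        # kwargs are forwarded to MyModule's constructor by OneShotStrategy
        super().__init__(MyModule, **kwargs)

    def _get_dataloader(self, train_dataloader, val_dataloaders):
        # interleave train/val batches, like the DARTS-family strategies above
        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
```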
nni/retiarii/strategy/__init__.py:

```diff
@@ -7,3 +7,4 @@ from .evolution import RegularizedEvolution
 from .tpe_strategy import TPEStrategy
 from .local_debug_strategy import _LocalDebugStrategy
 from .rl import PolicyBasedRL
+from .oneshot import DARTS, Proxyless, SNAS, ENAS, RandomOneShot
```
nni/retiarii/strategy/base.py:

```diff
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

 import abc
-from typing import List
+from typing import List, Any

 from ..graph import Model
 from ..mutator import Mutator
@@ -13,3 +13,6 @@ class BaseStrategy(abc.ABC):
     @abc.abstractmethod
     def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
         pass
+
+    def export_top_models(self) -> List[Any]:
+        raise NotImplementedError('"export_top_models" is not implemented.')
```
nni/retiarii/strategy/oneshot.py (new file, 0 → 100644):

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .base import BaseStrategy

try:
    from nni.retiarii.oneshot.pytorch.strategy import (  # pylint: disable=unused-import
        DARTS, SNAS, Proxyless, ENAS, RandomOneShot
    )
except ImportError as import_err:
    _import_err = import_err

    class ImportFailedStrategy(BaseStrategy):
        def run(self, base_model, applied_mutators):
            raise _import_err

    # otherwise type checking would point to the wrong location
    globals()['DARTS'] = ImportFailedStrategy
    globals()['SNAS'] = ImportFailedStrategy
    globals()['Proxyless'] = ImportFailedStrategy
    globals()['ENAS'] = ImportFailedStrategy
    globals()['RandomOneShot'] = ImportFailedStrategy
```
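The effect of this guard is that importing the strategy package never fails; the ImportError only surfaces when a strategy is actually run without PyTorch installed. A hedged illustration:

```python
# With pytorch (and lightning) available, this resolves to the real strategy class:
from nni.retiarii.strategy import DARTS
search_strategy = DARTS()

# Without pytorch, the same import still succeeds, but DARTS resolves to
# ImportFailedStrategy, so the saved ImportError is deferred until run time:
#     search_strategy.run(base_model_ir, [])   # re-raises the original ImportError
```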
test/ut/retiarii/test_oneshot.py:

```diff
@@ -8,12 +8,10 @@ from torchvision import transforms
 from torchvision.datasets import MNIST
 from torch.utils.data.sampler import RandomSampler

+from nni.retiarii import strategy, model_wrapper
+from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
 from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader
 from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
-from nni.retiarii.oneshot.pytorch import (ConcatenateTrainValDataLoader, DartsModule, EnasModule, SNASModule,
-                                          InterleavedTrainValDataLoader, ProxylessModule, RandomSampleModule)


 class DepthwiseSeparableConv(nn.Module):
@@ -26,6 +24,7 @@ class DepthwiseSeparableConv(nn.Module):
         return self.pointwise(self.depthwise(x))


+@model_wrapper
 class Net(pl.LightningModule):
     def __init__(self):
         super().__init__()
@@ -68,7 +67,6 @@ class Net(pl.LightningModule):
         return output


-@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def prepare_model_data():
     base_model = Net()
     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
@@ -86,53 +84,42 @@ def prepare_model_data():
     return base_model, train_loader, valid_loader, trainer_kwargs


+def _test_strategy(strategy_):
+    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
+    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
+    experiment = RetiariiExperiment(base_model, cls, strategy=strategy_)
+
+    config = RetiariiExeConfig()
+    config.execution_engine = 'oneshot'
+    experiment.run(config)
+
+    assert isinstance(experiment.export_top_models()[0], dict)
+
+
 @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_darts():
-    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
-    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
-    cls.module.set_model(base_model)
-    darts_model = DartsModule(cls.module)
-    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
-    cls.trainer.fit(darts_model, para_loader)
+    _test_strategy(strategy.DARTS())


 @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_proxyless():
-    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
-    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
-    cls.module.set_model(base_model)
-    proxyless_model = ProxylessModule(cls.module)
-    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
-    cls.trainer.fit(proxyless_model, para_loader)
+    _test_strategy(strategy.Proxyless())


 @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_enas():
-    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
-    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
-    cls.module.set_model(base_model)
-    enas_model = EnasModule(cls.module)
-    concat_loader = ConcatenateTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
-    cls.trainer.fit(enas_model, concat_loader)
+    _test_strategy(strategy.ENAS())


 @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_random():
-    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
-    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
-    cls.module.set_model(base_model)
-    random_model = RandomSampleModule(cls.module)
-    cls.trainer.fit(random_model, cls.train_dataloader, cls.val_dataloaders)
+    _test_strategy(strategy.RandomOneShot())


 @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_snas():
-    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
-    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
-    cls.module.set_model(base_model)
-    proxyless_model = SNASModule(cls.module, 1, use_temp_anneal=True)
-    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
-    cls.trainer.fit(proxyless_model, para_loader)
+    _test_strategy(strategy.SNAS())


 if __name__ == '__main__':
```