Unverified Commit 5b7dac5c authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Wrap one-shot algorithms as strategies (#4571)

parent c13392ab
...@@ -12,4 +12,5 @@ rstcheck ...@@ -12,4 +12,5 @@ rstcheck
sphinx sphinx
sphinx-argparse-nni >= 0.4.0 sphinx-argparse-nni >= 0.4.0
sphinx-gallery sphinx-gallery
sphinxcontrib-bibtex
git+https://github.com/bashtage/sphinx-material.git git+https://github.com/bashtage/sphinx-material.git
...@@ -60,38 +60,12 @@ Evaluators ...@@ -60,38 +60,12 @@ Evaluators
.. autoclass:: nni.retiarii.evaluator.pytorch.lightning.Regression .. autoclass:: nni.retiarii.evaluator.pytorch.lightning.Regression
:members: :members:
Oneshot Trainers
----------------
.. autoclass:: nni.retiarii.oneshot.pytorch.DartsTrainer
:members:
.. autoclass:: nni.retiarii.oneshot.pytorch.EnasTrainer
:members:
.. autoclass:: nni.retiarii.oneshot.pytorch.ProxylessTrainer
:members:
.. autoclass:: nni.retiarii.oneshot.pytorch.SinglePathTrainer
:members:
Exploration Strategies Exploration Strategies
---------------------- ----------------------
.. autoclass:: nni.retiarii.strategy.Random .. automodule:: nni.retiarii.strategy
:members:
.. autoclass:: nni.retiarii.strategy.GridSearch
:members:
.. autoclass:: nni.retiarii.strategy.RegularizedEvolution
:members:
.. autoclass:: nni.retiarii.strategy.TPEStrategy
:members:
.. autoclass:: nni.retiarii.strategy.PolicyBasedRL
:members: :members:
:imported-members:
Retiarii Experiments Retiarii Experiments
-------------------- --------------------
...@@ -111,6 +85,17 @@ CGO Execution ...@@ -111,6 +85,17 @@ CGO Execution
.. autofunction:: nni.retiarii.evaluator.pytorch.cgo.evaluator.Regression .. autofunction:: nni.retiarii.evaluator.pytorch.cgo.evaluator.Regression
One-shot Implementation
-----------------------
.. automodule:: nni.retiarii.oneshot
:members:
:imported-members:
.. automodule:: nni.retiarii.oneshot.pytorch
:members:
:imported-members:
Utilities Utilities
--------- ---------
...@@ -120,4 +105,7 @@ Utilities ...@@ -120,4 +105,7 @@ Utilities
.. autofunction:: nni.retiarii.fixed_arch .. autofunction:: nni.retiarii.fixed_arch
Citations
---------
.. bibliography::
...@@ -49,6 +49,7 @@ extensions = [ ...@@ -49,6 +49,7 @@ extensions = [
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
'sphinx.ext.viewcode', 'sphinx.ext.viewcode',
'sphinx.ext.intersphinx', 'sphinx.ext.intersphinx',
'sphinxcontrib.bibtex',
# 'nbsphinx', # nbsphinx has conflicts with sphinx-gallery. # 'nbsphinx', # nbsphinx has conflicts with sphinx-gallery.
'sphinx.ext.extlinks', 'sphinx.ext.extlinks',
'IPython.sphinxext.ipython_console_highlighting', 'IPython.sphinxext.ipython_console_highlighting',
...@@ -62,6 +63,9 @@ extensions = [ ...@@ -62,6 +63,9 @@ extensions = [
# Add mock modules # Add mock modules
autodoc_mock_imports = ['apex', 'nni_node', 'tensorrt', 'pycuda', 'nn_meter'] autodoc_mock_imports = ['apex', 'nni_node', 'tensorrt', 'pycuda', 'nn_meter']
# Bibliography files
bibtex_bibfiles = ['refs.bib']
# Sphinx gallery examples # Sphinx gallery examples
sphinx_gallery_conf = { sphinx_gallery_conf = {
'examples_dirs': '../../examples/tutorials', # path to your example scripts 'examples_dirs': '../../examples/tutorials', # path to your example scripts
......
@inproceedings{liu2018darts,
title={DARTS: Differentiable Architecture Search},
author={Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
booktitle={International Conference on Learning Representations},
year={2018}
}
@inproceedings{cai2018proxylessnas,
title={ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware},
author={Cai, Han and Zhu, Ligeng and Han, Song},
booktitle={International Conference on Learning Representations},
year={2018}
}
@inproceedings{xie2018snas,
title={SNAS: stochastic neural architecture search},
author={Xie, Sirui and Zheng, Hehui and Liu, Chunxiao and Lin, Liang},
booktitle={International Conference on Learning Representations},
year={2018}
}
@inproceedings{pham2018efficient,
title={Efficient neural architecture search via parameters sharing},
author={Pham, Hieu and Guan, Melody and Zoph, Barret and Le, Quoc and Dean, Jeff},
booktitle={International conference on machine learning},
pages={4095--4104},
year={2018},
organization={PMLR}
}
...@@ -47,3 +47,8 @@ nav.md-tabs .md-tabs__item:not(:last-child) .md-tabs__link:after { ...@@ -47,3 +47,8 @@ nav.md-tabs .md-tabs__item:not(:last-child) .md-tabs__link:after {
.md-nav span.caption { .md-nav span.caption {
margin-top: 1.25em; margin-top: 1.25em;
} }
/* citation style */
.citation dt {
padding-right: 1em;
}
...@@ -34,7 +34,9 @@ from ..execution.utils import get_mutation_dict ...@@ -34,7 +34,9 @@ from ..execution.utils import get_mutation_dict
from ..graph import Evaluator from ..graph import Evaluator
from ..integration import RetiariiAdvisor from ..integration import RetiariiAdvisor
from ..mutator import Mutator from ..mutator import Mutator
from ..nn.pytorch.mutator import extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations from ..nn.pytorch.mutator import (
extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations, process_oneshot_mutations
)
from ..oneshot.interface import BaseOneShotTrainer from ..oneshot.interface import BaseOneShotTrainer
from ..serializer import is_model_wrapped from ..serializer import is_model_wrapped
from ..strategy import BaseStrategy from ..strategy import BaseStrategy
...@@ -86,7 +88,7 @@ class RetiariiExeConfig(ConfigBase): ...@@ -86,7 +88,7 @@ class RetiariiExeConfig(ConfigBase):
if key == 'trial_code_directory' and not (str(value) == '.' or os.path.isabs(value)): if key == 'trial_code_directory' and not (str(value) == '.' or os.path.isabs(value)):
raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!') raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
if key == 'execution_engine': if key == 'execution_engine':
assert value in ['base', 'py', 'cgo', 'benchmark'], f'The specified execution engine "{value}" is not supported.' assert value in ['base', 'py', 'cgo', 'benchmark', 'oneshot'], f'The specified execution engine "{value}" is not supported.'
self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
self.__dict__[key] = value self.__dict__[key] = value
...@@ -115,9 +117,11 @@ _validation_rules = { ...@@ -115,9 +117,11 @@ _validation_rules = {
} }
def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_input=None): def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):
# TODO: this logic might need to be refactored into execution engine # TODO: this logic might need to be refactored into execution engine
if full_ir: if oneshot:
base_model_ir, mutators = process_oneshot_mutations(base_model, evaluator)
elif full_ir:
try: try:
script_module = torch.jit.script(base_model) script_module = torch.jit.script(base_model)
except Exception as e: except Exception as e:
...@@ -134,7 +138,7 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_ ...@@ -134,7 +138,7 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_
mutators = process_inline_mutation(base_model_ir) mutators = process_inline_mutation(base_model_ir)
else: else:
base_model_ir, mutators = extract_mutation_from_pt_module(base_model) base_model_ir, mutators = extract_mutation_from_pt_module(base_model)
base_model_ir.evaluator = trainer base_model_ir.evaluator = evaluator
if mutators is not None and applied_mutators: if mutators is not None and applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, ' raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
...@@ -144,7 +148,7 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_ ...@@ -144,7 +148,7 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_
return base_model_ir, applied_mutators return base_model_ir, applied_mutators
def debug_mutated_model(base_model, trainer, applied_mutators): def debug_mutated_model(base_model, evaluator, applied_mutators):
""" """
Locally run only one trial without launching an experiment for debug purpose, then exit. Locally run only one trial without launching an experiment for debug purpose, then exit.
For example, it can be used to quickly check shape mismatch. For example, it can be used to quickly check shape mismatch.
...@@ -152,16 +156,18 @@ def debug_mutated_model(base_model, trainer, applied_mutators): ...@@ -152,16 +156,18 @@ def debug_mutated_model(base_model, trainer, applied_mutators):
Specifically, it applies mutators (default to choose the first candidate for the choices) Specifically, it applies mutators (default to choose the first candidate for the choices)
to generate a new model, then run this model locally. to generate a new model, then run this model locally.
The model will be parsed with graph execution engine.
Parameters Parameters
---------- ----------
base_model : nni.retiarii.nn.pytorch.nn.Module base_model : nni.retiarii.nn.pytorch.nn.Module
the base model the base model
trainer : nni.retiarii.evaluator evaluator : nni.retiarii.graph.Evaluator
the training class of the generated models the training class of the generated models
applied_mutators : list applied_mutators : list
a list of mutators that will be applied on the base model for generating a new model a list of mutators that will be applied on the base model for generating a new model
""" """
base_model_ir, applied_mutators = preprocess_model(base_model, trainer, applied_mutators) base_model_ir, applied_mutators = preprocess_model(base_model, evaluator, applied_mutators)
from ..strategy import _LocalDebugStrategy from ..strategy import _LocalDebugStrategy
strategy = _LocalDebugStrategy() strategy = _LocalDebugStrategy()
strategy.run(base_model_ir, applied_mutators) strategy.run(base_model_ir, applied_mutators)
...@@ -169,21 +175,99 @@ def debug_mutated_model(base_model, trainer, applied_mutators): ...@@ -169,21 +175,99 @@ def debug_mutated_model(base_model, trainer, applied_mutators):
class RetiariiExperiment(Experiment): class RetiariiExperiment(Experiment):
def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotTrainer], """
applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None): The entry for a NAS experiment.
Users can use this class to start/stop or inspect an experiment, like exporting the results.
Experiment is a sub-class of :class:`nni.experiment.Experiment`, there are many similarities such as
configurable training service to distributed running the experiment on remote server.
But unlike :class:`nni.experiment.Experiment`, RetiariiExperiment doesn't support configure:
- ``trial_code_directory``, which can only be current working directory.
- ``search_space``, which is auto-generated in NAS.
- ``trial_command``, which must be ``python -m nni.retiarii.trial_entry`` to launch the modularized trial code.
RetiariiExperiment also doesn't have tuner/assessor/advisor, because they are also implemented in strategy.
Also, unlike :class:`nni.experiment.Experiment` which is bounded to a node server,
RetiariiExperiment optionally starts a node server to schedule the trials, when the strategy is a multi-trial strategy.
When the strategy is one-shot, the step of launching node server is omitted, and the experiment is run locally by default.
Configurations of experiments, such as execution engine, number of GPUs allocated,
should be put into a :class:`RetiariiExeConfig` and used as an argument of :meth:`RetiariiExperiment.run`.
Parameters
----------
base_model : nn.Module
The model defining the search space / base skeleton without mutation.
It should be wrapped by decorator ``nni.retiarii.model_wrapper``.
evaluator : nni.retiarii.Evaluator, default = None
Evaluator for the experiment.
If you are using a one-shot trainer, it should be placed here, although this usage is deprecated.
applied_mutators : list of nni.retiarii.Mutator, default = None
Mutators to mutate the base model. If none, mutators are skipped.
Note that when ``base_model`` uses inline mutations (e.g., LayerChoice), ``applied_mutators`` must be empty / none.
strategy : nni.retiarii.strategy.BaseStrategy, default = None
Exploration strategy. Can be multi-trial or one-shot.
trainer : BaseOneShotTrainer
Kept for compatibility purposes.
Examples
--------
Multi-trial NAS:
>>> base_model = Net()
>>> search_strategy = strategy.Random()
>>> model_evaluator = FunctionalEvaluator(evaluate_model)
>>> exp = RetiariiExperiment(base_model, model_evaluator, [], search_strategy)
>>> exp_config = RetiariiExeConfig('local')
>>> exp_config.trial_concurrency = 2
>>> exp_config.max_trial_number = 20
>>> exp_config.training_service.use_active_gpu = False
>>> exp.run(exp_config, 8081)
One-shot NAS:
>>> base_model = Net()
>>> search_strategy = strategy.DARTS()
>>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
>>> exp = RetiariiExperiment(base_model, evaluator, [], search_strategy)
>>> exp_config = RetiariiExeConfig()
>>> exp_config.execution_engine = 'oneshot'  # must be set for one-shot strategy
>>> exp.run(exp_config)
Export top models:
>>> for model_dict in exp.export_top_models(formatter='dict'):
... print(model_dict)
>>> with nni.retiarii.fixed_arch(model_dict):
... final_model = Net()
"""
def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = None,
applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None,
trainer: BaseOneShotTrainer = None):
if trainer is not None:
warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
evaluator = trainer
if evaluator is None:
raise ValueError('Evaluator should not be none.')
# TODO: The current design of init interface of Retiarii experiment needs to be reviewed. # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
self.config: RetiariiExeConfig = None self.config: RetiariiExeConfig = None
self.port: Optional[int] = None self.port: Optional[int] = None
self.base_model = base_model self.base_model = base_model
self.trainer = trainer self.evaluator: Evaluator = evaluator
self.applied_mutators = applied_mutators self.applied_mutators = applied_mutators
self.strategy = strategy self.strategy = strategy
self._dispatcher = RetiariiAdvisor() # FIXME: this is only a workaround
self._dispatcher_thread: Optional[Thread] = None from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
self._proc: Optional[Popen] = None if not isinstance(strategy, OneShotStrategy):
self._pipe: Optional[Pipe] = None self._dispatcher = RetiariiAdvisor()
self._dispatcher_thread: Optional[Thread] = None
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
self.url_prefix = None self.url_prefix = None
...@@ -196,11 +280,11 @@ class RetiariiExperiment(Experiment): ...@@ -196,11 +280,11 @@ class RetiariiExperiment(Experiment):
def _start_strategy(self): def _start_strategy(self):
base_model_ir, self.applied_mutators = preprocess_model( base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.trainer, self.applied_mutators, self.base_model, self.evaluator, self.applied_mutators,
full_ir=self.config.execution_engine not in ['py', 'benchmark'], full_ir=self.config.execution_engine not in ['py', 'benchmark'],
dummy_input=self.config.dummy_input dummy_input=self.config.dummy_input
) )
self.applied_mutators += process_evaluator_mutations(self.trainer, self.applied_mutators) self.applied_mutators += process_evaluator_mutations(self.evaluator, self.applied_mutators)
_logger.info('Start strategy...') _logger.info('Start strategy...')
search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators) search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
...@@ -302,8 +386,23 @@ class RetiariiExperiment(Experiment): ...@@ -302,8 +386,23 @@ class RetiariiExperiment(Experiment):
Run the experiment. Run the experiment.
This function will block until experiment finish or error. This function will block until experiment finish or error.
""" """
if isinstance(self.trainer, BaseOneShotTrainer): if isinstance(self.evaluator, BaseOneShotTrainer):
self.trainer.fit() # TODO: will throw a deprecation warning soon
# warnings.warn('You are using the old implementation of one-shot algos based on One-shot trainer. '
# 'We will try to convert this trainer to our new implementation to run the algorithm. '
# 'In case you want to stick to the old implementation, '
# 'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
self.evaluator.fit()
if config is None:
warnings.warn('config = None is deprecated in future. If you are running a one-shot experiment, '
'please consider creating a config and set execution engine to `oneshot`.', DeprecationWarning)
config = RetiariiExeConfig()
config.execution_engine = 'oneshot'
if config.execution_engine == 'oneshot':
base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.evaluator, self.applied_mutators, oneshot=True)
self.strategy.run(base_model_ir, self.applied_mutators)
else: else:
assert config is not None, 'You are using classic search mode, config cannot be None!' assert config is not None, 'You are using classic search mode, config cannot be None!'
self.config = config self.config = config
...@@ -388,10 +487,14 @@ class RetiariiExperiment(Experiment): ...@@ -388,10 +487,14 @@ class RetiariiExperiment(Experiment):
""" """
if formatter == 'code': if formatter == 'code':
assert self.config.execution_engine != 'py', 'You should use `dict` formatter when using Python execution engine.' assert self.config.execution_engine != 'py', 'You should use `dict` formatter when using Python execution engine.'
if isinstance(self.trainer, BaseOneShotTrainer): if isinstance(self.evaluator, BaseOneShotTrainer):
assert top_k == 1, 'Only support top_k is 1 for now.' assert top_k == 1, 'Only support top_k is 1 for now.'
return self.trainer.export() return self.evaluator.export()
else: try:
# this currently works for one-shot algorithms
return self.strategy.export_top_models(top_k=top_k)
except NotImplementedError:
# when strategy hasn't implemented its own export logic
all_models = filter(lambda m: m.metric is not None, list_models()) all_models = filter(lambda m: m.metric is not None, list_models())
assert optimize_mode in ['maximize', 'minimize'] assert optimize_mode in ['maximize', 'minimize']
all_models = sorted(all_models, key=lambda m: m.metric, reverse=optimize_mode == 'maximize') all_models = sorted(all_models, key=lambda m: m.metric, reverse=optimize_mode == 'maximize')
......
...@@ -84,6 +84,8 @@ class Model: ...@@ -84,6 +84,8 @@ class Model:
Attributes Attributes
---------- ----------
python_object
Python object of base model. It will be none when the base model is not available.
python_class python_class
Python class that base model is converted from. Python class that base model is converted from.
python_init_params python_init_params
...@@ -110,6 +112,7 @@ class Model: ...@@ -110,6 +112,7 @@ class Model:
def __init__(self, _internal=False): def __init__(self, _internal=False):
assert _internal, '`Model()` is private, use `model.fork()` instead' assert _internal, '`Model()` is private, use `model.fork()` instead'
self.model_id: int = uid('model') self.model_id: int = uid('model')
self.python_object: Optional[Any] = None # type is uncertain because it could differ between DL frameworks
self.python_class: Optional[Type] = None self.python_class: Optional[Type] = None
self.python_init_params: Optional[Dict[str, Any]] = None self.python_init_params: Optional[Dict[str, Any]] = None
......
...@@ -409,6 +409,20 @@ def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mu ...@@ -409,6 +409,20 @@ def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mu
return mutators return mutators
# the following are written for one-shot mode
# they shouldn't technically belong here, but all other engines are written here
# let's refactor later
def process_oneshot_mutations(base_model: nn.Module, evaluator: Evaluator):
# It's not intuitive, at all, (actually very hacky) to wrap a `base_model` and `evaluator` into a graph.Model.
# But unfortunately, this is the required interface of strategy.
model = Model(_internal=True)
model.python_object = base_model
# no need to set evaluator here because it will be set after this method is called
return model, []
# utility functions # utility functions
......
...@@ -5,6 +5,6 @@ from .darts import DartsTrainer ...@@ -5,6 +5,6 @@ from .darts import DartsTrainer
from .enas import EnasTrainer from .enas import EnasTrainer
from .proxyless import ProxylessTrainer from .proxyless import ProxylessTrainer
from .random import SinglePathTrainer, RandomTrainer from .random import SinglePathTrainer, RandomTrainer
from .differentiable import DartsModule, ProxylessModule, SNASModule from .differentiable import DartsModule, ProxylessModule, SnasModule
from .sampling import EnasModule, RandomSampleModule from .sampling import EnasModule, RandomSamplingModule
from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from typing import Dict, Type, Callable, List, Optional
import pytorch_lightning as pl import pytorch_lightning as pl
import torch.optim as optim import torch.optim as optim
import torch.nn as nn import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
ReplaceDictType = Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]]
def _replace_module_with_type(root_module, replace_dict, modules): def _replace_module_with_type(root_module: nn.Module, replace_dict: ReplaceDictType, modules: List[nn.Module]):
""" """
Replace xxxChoice in user's model with NAS modules. Replace xxxChoice in user's model with NAS modules.
...@@ -45,31 +49,50 @@ def _replace_module_with_type(root_module, replace_dict, modules): ...@@ -45,31 +49,50 @@ def _replace_module_with_type(root_module, replace_dict, modules):
class BaseOneShotLightningModule(pl.LightningModule): class BaseOneShotLightningModule(pl.LightningModule):
_custom_replace_dict_note = """custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
The custom xxxChoice replace method. Keys should be ``xxxChoice`` type.
Values should be callables accepting an ``nn.Module`` and returning an ``nn.Module``.
This custom replace dict will override the default replace dict of each NAS method.
""" """
The base class for all one-shot NAS modules. Essential function such as preprocessing user's model, redirecting lightning
hooks for user's model, configuring optimizers and exporting NAS result are implemented in this class. _inner_module_note = """inner_module : pytorch_lightning.LightningModule
It's a `LightningModule <https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html>`__
that defines computations, train/val loops, optimizers in a single class.
When used in NNI, the ``inner_module`` is the combination of instances of evaluator + base model
(to be precise, a base model wrapped with LightningModule in evaluator).
"""
__doc__ = """
The base class for all one-shot NAS modules.
In NNI, we try to separate the "search" part and "training" part in one-shot NAS.
The "training" part is defined with evaluator interface (has to be lightning evaluator interface to work with oneshot).
Since the lightning evaluator has already broken down the training into minimal building blocks,
we can re-assemble them after combining them with the "search" part of a particular algorithm.
After the re-assembling, this module has defined all the search + training. The experiment can use a lightning trainer
(which is another part in the evaluator) to train this module, so as to complete the search process.
Essential function such as preprocessing user's model, redirecting lightning hooks for user's model,
configuring optimizers and exporting NAS result are implemented in this class.
Attributes Attributes
---------- ----------
nas_modules : List[nn.Module] nas_modules : List[nn.Module]
The replace result of a specific NAS method. xxxChoice will be replaced with some other modules with respect to the The replace result of a specific NAS method.
NAS method. xxxChoice will be replaced with some other modules with respect to the NAS method.
Parameters Parameters
---------- ----------
base_model : pl.LightningModule """ + _inner_module_note + _custom_replace_dict_note
The evaluator in ``nni.retiarii.evaluator.lightning``. User defined model is wrapped by base_model, and base_model will
be wrapped by this model.
custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
The custom xxxChoice replace method. Keys should be xxxChoice type and values should return an ``nn.module``. This custom
replace dict will override the default replace dict of each NAS method.
"""
automatic_optimization = False automatic_optimization = False
def __init__(self, base_model, custom_replace_dict=None): def __init__(self, inner_module: pl.LightningModule, custom_replace_dict: Optional[ReplaceDictType] = None):
super().__init__() super().__init__()
assert isinstance(base_model, pl.LightningModule) assert isinstance(inner_module, pl.LightningModule)
self.model = base_model self.model = inner_module
# replace xxxChoice with respect to NAS alg # replace xxxChoice with respect to NAS alg
# replaced modules are stored in self.nas_modules # replaced modules are stored in self.nas_modules
...@@ -85,16 +108,18 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -85,16 +108,18 @@ class BaseOneShotLightningModule(pl.LightningModule):
return self.model(x) return self.model(x)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
# You can use self.architecture_optimizers or self.user_optimizers to get optimizers in """This is the implementation of what happens in training loops of one-shot algos.
# your own training step. It usually calls ``self.model.training_step`` which implements the real training recipe of the users' model.
"""
return self.model.training_step(batch, batch_idx) return self.model.training_step(batch, batch_idx)
def configure_optimizers(self): def configure_optimizers(self):
""" """
Combine architecture optimizers and user's model optimizers. Combine architecture optimizers and user's model optimizers.
You can overwrite configure_architecture_optimizers if architecture optimizers are needed in your NAS algorithm. You can overwrite configure_architecture_optimizers if architecture optimizers are needed in your NAS algorithm.
By now ``self.model`` is currently a :class:`nni.retiarii.evaluator.pytorch.lightning._SupervisedLearningModule` For now ``self.model`` is tested against :class:`nni.retiarii.evaluator.pytorch.lightning._SupervisedLearningModule`
and it only returns 1 optimizer. But for extendibility, codes for other return value types are also implemented. and it only returns 1 optimizer.
But for extendibility, codes for other return value types are also implemented.
""" """
# pylint: disable=assignment-from-none # pylint: disable=assignment-from-none
arc_optimizers = self.configure_architecture_optimizers() arc_optimizers = self.configure_architecture_optimizers()
...@@ -178,8 +203,8 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -178,8 +203,8 @@ class BaseOneShotLightningModule(pl.LightningModule):
@property @property
def default_replace_dict(self): def default_replace_dict(self):
""" """
Default xxxChoice replace dict. This is called in ``__init__`` to get the default replace functions for your NAS algorithm. Default ``xxxChoice`` replace dict. This is called in ``__init__`` to get the default replace functions for your NAS algorithm.
Note that your default replace functions may be overridden by user-defined custom_replace_dict. Note that your default replace functions may be overridden by user-defined ``custom_replace_dict``.
Returns Returns
---------- ----------
......
...@@ -2,12 +2,14 @@ ...@@ -2,12 +2,14 @@
# Licensed under the MIT license. # Licensed under the MIT license.
from collections import OrderedDict from collections import OrderedDict
from typing import Optional
import pytorch_lightning as pl
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from .base_lightning import BaseOneShotLightningModule from .base_lightning import BaseOneShotLightningModule, ReplaceDictType
class DartsLayerChoice(nn.Module): class DartsLayerChoice(nn.Module):
...@@ -63,17 +65,35 @@ class DartsInputChoice(nn.Module): ...@@ -63,17 +65,35 @@ class DartsInputChoice(nn.Module):
class DartsModule(BaseOneShotLightningModule): class DartsModule(BaseOneShotLightningModule):
""" _darts_note = """
The DARTS module. Each iteration consists of 2 training phases. The phase 1 is architecture step, in which model parameters are DARTS :cite:p:`liu2018darts` algorithm is one of the most fundamental one-shot algorithm.
frozen and the architecture parameters are trained. The phase 2 is model step, in which architecture parameters are frozen and
model parameters are trained. See [darts] for details.
The DARTS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.
Reference DARTS repeats iterations, where each iteration consists of 2 training phases.
The phase 1 is architecture step, in which model parameters are frozen and the architecture parameters are trained.
The phase 2 is model step, in which architecture parameters are frozen and model parameters are trained.
The current implementation is for DARTS in first order. Second order (unrolled) is not supported yet.
{{module_notes}}
Parameters
---------- ----------
.. [darts] H. Liu, K. Simonyan, and Y. Yang, “DARTS: Differentiable Architecture Search,” presented at the {{module_params}}
International Conference on Learning Representations, Sep. 2018. Available: https://openreview.net/forum?id=S1eYHoC5FX {base_params}
""" arc_learning_rate : float
Learning rate for architecture optimizer. Default: 3.0e-4
""".format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
__doc__ = _darts_note.format(
module_notes='The DARTS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.',
module_params=BaseOneShotLightningModule._inner_module_note,
)
def __init__(self, inner_module: pl.LightningModule,
custom_replace_dict: Optional[ReplaceDictType] = None,
arc_learning_rate: float = 3.0E-4):
super().__init__(inner_module, custom_replace_dict=custom_replace_dict)
self.arc_learning_rate = arc_learning_rate
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
# grad manually # grad manually
...@@ -118,8 +138,8 @@ class DartsModule(BaseOneShotLightningModule): ...@@ -118,8 +138,8 @@ class DartsModule(BaseOneShotLightningModule):
@property @property
def default_replace_dict(self): def default_replace_dict(self):
return { return {
LayerChoice : DartsLayerChoice, LayerChoice: DartsLayerChoice,
InputChoice : DartsInputChoice InputChoice: DartsInputChoice
} }
def configure_architecture_optimizers(self): def configure_architecture_optimizers(self):
...@@ -132,7 +152,7 @@ class DartsModule(BaseOneShotLightningModule): ...@@ -132,7 +152,7 @@ class DartsModule(BaseOneShotLightningModule):
else: else:
ctrl_params[m.name] = m.alpha ctrl_params[m.name] = m.alpha
ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), 3.e-4, betas=(0.5, 0.999), ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), 3.e-4, betas=(0.5, 0.999),
weight_decay=1.0E-3) weight_decay=1.0E-3)
return ctrl_optim return ctrl_optim
...@@ -279,28 +299,34 @@ class ProxylessInputChoice(nn.Module): ...@@ -279,28 +299,34 @@ class ProxylessInputChoice(nn.Module):
class ProxylessModule(DartsModule): class ProxylessModule(DartsModule):
""" _proxyless_note = """
The Proxyless Module. This is a darts-based method that resamples the architecture to reduce memory consumption. Implementation of ProxylessNAS :cite:p:`cai2018proxylessnas`.
The Proxyless Module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`. It's a DARTS-based method that resamples the architecture to reduce memory consumption.
Essentially, it samples one path on forward,
and implements its own backward to update the architecture parameters based on only one path.
Reference {{module_notes}}
Parameters
---------- ----------
.. [proxyless] H. Cai, L. Zhu, and S. Han, “ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware,” presented {{module_params}}
at the International Conference on Learning Representations, Sep. 2018. Available: https://openreview.net/forum?id=HylVB3AqYm {base_params}
""" arc_learning_rate : float
Learning rate for architecture optimizer. Default: 3.0e-4
""".format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
__doc__ = _proxyless_note.format(
module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.',
module_params=BaseOneShotLightningModule._inner_module_note,
)
@property @property
def default_replace_dict(self): def default_replace_dict(self):
return { return {
LayerChoice : ProxylessLayerChoice, LayerChoice: ProxylessLayerChoice,
InputChoice : ProxylessInputChoice InputChoice: ProxylessInputChoice
} }
def configure_architecture_optimizers(self):
ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], 3.e-4,
weight_decay=0, betas=(0, 0.999), eps=1e-8)
return ctrl_optim
def _resample(self): def _resample(self):
for _, m in self.nas_modules: for _, m in self.nas_modules:
m.resample() m.resample()
...@@ -312,52 +338,60 @@ class ProxylessModule(DartsModule): ...@@ -312,52 +338,60 @@ class ProxylessModule(DartsModule):
class SNASLayerChoice(DartsLayerChoice): class SNASLayerChoice(DartsLayerChoice):
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
self.one_hot = F.gumbel_softmax(self.alpha, self.temp) one_hot = F.gumbel_softmax(self.alpha, self.temp)
op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()]) op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1) alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
yhat = torch.sum(op_results * self.one_hot.view(*alpha_shape), 0) yhat = torch.sum(op_results * one_hot.view(*alpha_shape), 0)
return yhat return yhat
class SNASInputChoice(DartsInputChoice): class SNASInputChoice(DartsInputChoice):
def forward(self, inputs): def forward(self, inputs):
self.one_hot = F.gumbel_softmax(self.alpha, self.temp) one_hot = F.gumbel_softmax(self.alpha, self.temp)
inputs = torch.stack(inputs) inputs = torch.stack(inputs)
alpha_shape = [-1] + [1] * (len(inputs.size()) - 1) alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
yhat = torch.sum(inputs * self.one_hot.view(*alpha_shape), 0) yhat = torch.sum(inputs * one_hot.view(*alpha_shape), 0)
return yhat return yhat
class SNASModule(DartsModule): class SnasModule(DartsModule):
""" _snas_note = """
The SNAS Module. This is a darts-based method that uses gumble-softmax to simulate one-hot distribution. Implementation of SNAS :cite:p:`xie2018snas`.
The SNAS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`. It's a DARTS-based method that uses gumbel-softmax to simulate one-hot distribution.
Essentially, it samples one path on forward,
and implements its own backward to update the architecture parameters based on only one path.
{{module_notes}}
Parameters Parameters
---------- ----------
base_model : pl.LightningModule {{module_params}}
The evaluator in ``nni.retiarii.evaluator.lightning``. User defined model is wrapped by base_model, and base_model will {base_params}
be wrapped by this model. gumbel_temperature : float
gumble_temperature : float The initial temperature used in gumbel-softmax.
The initial temperature used in gumble-softmax.
use_temp_anneal : bool use_temp_anneal : bool
True: a linear annealing will be applied to gumble_temperature. False: run at a fixed temperature. See [snas] for details. If true, a linear annealing will be applied to ``gumbel_temperature``.
Otherwise, run at a fixed temperature. See :cite:t:`xie2018snas` for details.
min_temp : float min_temp : float
The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False. The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None arc_learning_rate : float
The custom xxxChoice replace method. Keys should be xxxChoice type and values should return an ``nn.module``. This custom Learning rate for architecture optimizer. Default: 3.0e-4
replace dict will override the default replace dict of each NAS method. """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
Reference __doc__ = _snas_note.format(
---------- module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.',
.. [snas] S. Xie, H. Zheng, C. Liu, and L. Lin, “SNAS: stochastic neural architecture search,” presented at the module_params=BaseOneShotLightningModule._inner_module_note,
International Conference on Learning Representations, Sep. 2018. Available: https://openreview.net/forum?id=rylqooRqK7 )
"""
def __init__(self, base_model, gumble_temperature = 1., use_temp_anneal = False, def __init__(self, inner_module,
min_temp = .33, custom_replace_dict=None): custom_replace_dict: Optional[ReplaceDictType] = None,
super().__init__(base_model, custom_replace_dict) arc_learning_rate: float = 3.0e-4,
self.temp = gumble_temperature gumbel_temperature: float = 1.,
self.init_temp = gumble_temperature use_temp_anneal: bool = False,
min_temp: float = .33):
super().__init__(inner_module, custom_replace_dict, arc_learning_rate=arc_learning_rate)
self.temp = gumbel_temperature
self.init_temp = gumbel_temperature
self.use_temp_anneal = use_temp_anneal self.use_temp_anneal = use_temp_anneal
self.min_temp = min_temp self.min_temp = min_temp
...@@ -366,14 +400,14 @@ class SNASModule(DartsModule): ...@@ -366,14 +400,14 @@ class SNASModule(DartsModule):
self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
self.temp = max(self.temp, self.min_temp) self.temp = max(self.temp, self.min_temp)
for _, nas_module in self.nas_modules: for _, nas_module in self.nas_modules:
nas_module.temp = self.temp nas_module.temp = self.temp
return self.model.on_epoch_start() return self.model.on_epoch_start()
@property @property
def default_replace_dict(self): def default_replace_dict(self):
return { return {
LayerChoice : SNASLayerChoice, LayerChoice: SNASLayerChoice,
InputChoice : SNASInputChoice InputChoice: SNASInputChoice
} }
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from typing import Dict, Any, Optional
import random import random
import pytorch_lightning as pl
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from nni.retiarii.nn.pytorch.api import LayerChoice, InputChoice from nni.retiarii.nn.pytorch.api import LayerChoice, InputChoice
from .random import PathSamplingLayerChoice, PathSamplingInputChoice from .random import PathSamplingLayerChoice, PathSamplingInputChoice
from .base_lightning import BaseOneShotLightningModule from .base_lightning import BaseOneShotLightningModule, ReplaceDictType
from .enas import ReinforceController, ReinforceField from .enas import ReinforceController, ReinforceField
class EnasModule(BaseOneShotLightningModule): class EnasModule(BaseOneShotLightningModule):
""" _enas_note = """
The ENAS module. There are 2 steps in an epoch. 1: training model parameters. 2: training ENAS RL agent. The agent will produce The implementation of ENAS :cite:p:`pham2018efficient`. There are 2 steps in an epoch.
a sample of model architecture to get the best reward. Firstly, training model parameters.
The ENASModule should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`. Secondly, training ENAS RL agent. The agent will produce a sample of model architecture to get the best reward.
{{module_notes}}
Parameters Parameters
---------- ----------
base_model : pl.LightningModule {{module_params}}
he evaluator in ``nni.retiarii.evaluator.lightning``. User defined model is wrapped by base_model, and base_model will {base_params}
be wrapped by this model.
ctrl_kwargs : dict ctrl_kwargs : dict
Optional kwargs that will be passed to :class:`ReinforceController`. Optional kwargs that will be passed to :class:`ReinforceController`.
entropy_weight : float entropy_weight : float
...@@ -33,22 +37,25 @@ class EnasModule(BaseOneShotLightningModule): ...@@ -33,22 +37,25 @@ class EnasModule(BaseOneShotLightningModule):
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``. Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
ctrl_steps_aggregate : int ctrl_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller. Number of steps that will be aggregated into one mini-batch for RL controller.
grad_clip : float ctrl_grad_clip : float
Gradient clipping value. Gradient clipping value of controller.
custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
The custom xxxChoice replace method. Keys should be xxxChoice type and values should return an ``nn.module``. This custom
replace dict will override the default replace dict of each NAS method. __doc__ = _enas_note.format(
module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
Reference module_params=BaseOneShotLightningModule._inner_module_note,
---------- )
.. [enas] H. Pham, M. Guan, B. Zoph, Q. Le, and J. Dean, “Efficient Neural Architecture Search via Parameters Sharing,”
in Proceedings of the 35th International Conference on Machine Learning, Jul. 2018, pp. 4095-4104. def __init__(self,
Available: https://proceedings.mlr.press/v80/pham18a.html inner_module: pl.LightningModule,
""" ctrl_kwargs: Dict[str, Any] = None,
def __init__(self, base_model, ctrl_kwargs = None, entropy_weight: float = 1e-4,
entropy_weight = 1e-4, skip_weight = .8, baseline_decay = .999, skip_weight: float = .8,
ctrl_steps_aggregate = 20, grad_clip = 0, custom_replace_dict = None): baseline_decay: float = .999,
super().__init__(base_model, custom_replace_dict) ctrl_steps_aggregate: float = 20,
ctrl_grad_clip: float = 0,
custom_replace_dict: Optional[ReplaceDictType] = None):
super().__init__(inner_module, custom_replace_dict)
self.nas_fields = [ReinforceField(name, len(module), self.nas_fields = [ReinforceField(name, len(module),
isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1) isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
...@@ -60,7 +67,7 @@ class EnasModule(BaseOneShotLightningModule): ...@@ -60,7 +67,7 @@ class EnasModule(BaseOneShotLightningModule):
self.baseline_decay = baseline_decay self.baseline_decay = baseline_decay
self.baseline = 0. self.baseline = 0.
self.ctrl_steps_aggregate = ctrl_steps_aggregate self.ctrl_steps_aggregate = ctrl_steps_aggregate
self.grad_clip = grad_clip self.ctrl_grad_clip = ctrl_grad_clip
def configure_architecture_optimizers(self): def configure_architecture_optimizers(self):
return optim.Adam(self.controller.parameters(), lr=3.5e-4) return optim.Adam(self.controller.parameters(), lr=3.5e-4)
...@@ -116,8 +123,8 @@ class EnasModule(BaseOneShotLightningModule): ...@@ -116,8 +123,8 @@ class EnasModule(BaseOneShotLightningModule):
self.manual_backward(rnn_step_loss) self.manual_backward(rnn_step_loss)
if (batch_idx + 1) % self.ctrl_steps_aggregate == 0: if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
if self.grad_clip > 0: if self.ctrl_grad_clip > 0:
nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip) nn.utils.clip_grad_norm_(self.controller.parameters(), self.ctrl_grad_clip)
arc_opt.step() arc_opt.step()
arc_opt.zero_grad() arc_opt.zero_grad()
...@@ -135,20 +142,22 @@ class EnasModule(BaseOneShotLightningModule): ...@@ -135,20 +142,22 @@ class EnasModule(BaseOneShotLightningModule):
return self.controller.resample() return self.controller.resample()
class RandomSampleModule(BaseOneShotLightningModule): class RandomSamplingModule(BaseOneShotLightningModule):
""" _random_note = """
Random Sampling NAS Algorithm. In each epoch, model parameters are trained after a uniformly random sampling of each choice. Random Sampling NAS Algorithm.
The training result is also a random sample of the search space. In each epoch, model parameters are trained after a uniformly random sampling of each choice.
Notably, the exporting result is **also a random sample** of the search space.
Parameters Parameters
---------- ----------
base_model : pl.LightningModule {{module_params}}
he evaluator in ``nni.retiarii.evaluator.lightning``. User defined model is wrapped by base_model, and base_model will {base_params}
be wrapped by this model. """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
The custom xxxChoice replace method. Keys should be xxxChoice type and values should return an ``nn.module``. This custom __doc__ = _random_note.format(
replace dict will override the default replace dict of each NAS method. module_params=BaseOneShotLightningModule._inner_module_note,
""" )
automatic_optimization = True automatic_optimization = True
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Strategy integration of one-shot.
This file is put here simply because it relies on "pytorch".
For consistency, please consider importing strategies from ``nni.retiarii.strategy``.
For example, ``nni.retiarii.strategy.DARTS`` (this requires pytorch to be installed of course).
"""
import warnings
from typing import Any, List, Optional, Type, Union, Tuple
import torch.nn as nn
from torch.utils.data import DataLoader
from nni.retiarii.graph import Model
from nni.retiarii.strategy.base import BaseStrategy
from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule
from .base_lightning import BaseOneShotLightningModule
from .differentiable import DartsModule, ProxylessModule, SnasModule
from .sampling import EnasModule, RandomSamplingModule
from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
class OneShotStrategy(BaseStrategy):
    """Adapt a one-shot lightning module into an exploration strategy.

    Parameters
    ----------
    oneshot_module
        Class (not instance) of the one-shot lightning module that implements
        the actual algorithm. It is instantiated lazily in :meth:`run`.
    **kwargs
        Keyword arguments forwarded verbatim to ``oneshot_module`` on construction.
    """

    def __init__(self, oneshot_module: Type[BaseOneShotLightningModule], **kwargs):
        self.oneshot_module = oneshot_module
        self.oneshot_kwargs = kwargs

        # Set by ``run``; ``export_top_models`` refuses to work before then.
        self.model: Optional[BaseOneShotLightningModule] = None

    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader) \
            -> Union[DataLoader, Tuple[DataLoader, DataLoader]]:
        """
        One-shot strategy typically requires a customized dataloader.

        If only train dataloader is produced, return one dataloader.
        Otherwise, return train dataloader and valid loader as a tuple.
        """
        raise NotImplementedError()

    def run(self, base_model: Model, applied_mutators):
        """Wrap the evaluator's module in the one-shot module and fit it.

        One-shot strategies don't use ``applied_mutators``; mutation points are
        discovered by the one-shot module directly on the user's model.
        """
        hint = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'

        user_model: nn.Module = base_model.python_object
        # Guard clauses: these only hold when the ``oneshot`` execution engine
        # built the graph, hence the hint appended to each message.
        if not isinstance(user_model, nn.Module):
            raise TypeError('Model is not a nn.Module. ' + hint)
        if applied_mutators:
            raise ValueError('Mutator is not empty. ' + hint)
        if not isinstance(base_model.evaluator, Lightning):
            raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')

        evaluator: Lightning = base_model.evaluator
        evaluator_module: LightningModule = evaluator.module
        evaluator_module.set_model(user_model)
        self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)

        loaders = self._get_dataloader(evaluator.train_dataloader, evaluator.val_dataloaders)
        if isinstance(loaders, tuple):
            # Subclass produced separate train / validation loaders.
            train_loader, val_loader = loaders
            evaluator.trainer.fit(self.model, train_loader, val_loader)
        else:
            # Subclass fused everything into a single loader.
            evaluator.trainer.fit(self.model, loaders)

    def export_top_models(self, top_k: int = 1) -> List[Any]:
        """Return the single architecture chosen by the fitted one-shot module."""
        if self.model is None:
            raise RuntimeError('One-shot strategy needs to be run before export.')
        if top_k != 1:
            warnings.warn('One-shot strategy currently only supports exporting top-1 model.', RuntimeWarning)
        return [self.model.export()]
class DARTS(OneShotStrategy):
    # Reuse the module's docstring so strategy and module docs stay in sync.
    __doc__ = DartsModule._darts_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        # Every keyword argument is handed to ``DartsModule`` unchanged.
        super().__init__(DartsModule, **kwargs)

    def _get_dataloader(self, train_dataloader, val_dataloaders):
        # DARTS alternates an architecture step (validation batch) with a
        # model step (training batch), hence the interleaved loader.
        loader = InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
        return loader
class Proxyless(OneShotStrategy):
    # Reuse the module's docstring so strategy and module docs stay in sync.
    __doc__ = ProxylessModule._proxyless_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        # Bug fix: this strategy previously wrapped ``EnasModule`` (copy-paste
        # error); it must wrap ``ProxylessModule``, matching ``__doc__`` above.
        super().__init__(ProxylessModule, **kwargs)

    def _get_dataloader(self, train_dataloader, val_dataloaders):
        # Proxyless is DARTS-based: training and validation batches interleave.
        return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
class SNAS(OneShotStrategy):
    # Reuse the module's docstring so strategy and module docs stay in sync.
    __doc__ = SnasModule._snas_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        # Every keyword argument is handed to ``SnasModule`` unchanged.
        super().__init__(SnasModule, **kwargs)

    def _get_dataloader(self, train_dataloader, val_dataloaders):
        # SNAS is DARTS-based: training and validation batches interleave.
        loader = InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
        return loader
class ENAS(OneShotStrategy):
    # Reuse the module's docstring so strategy and module docs stay in sync.
    __doc__ = EnasModule._enas_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        # Every keyword argument is handed to ``EnasModule`` unchanged.
        super().__init__(EnasModule, **kwargs)

    def _get_dataloader(self, train_dataloader, val_dataloaders):
        # ENAS trains model weights on the train split, then the RL controller
        # on the validation split, so the two loaders are concatenated.
        loader = ConcatenateTrainValDataLoader(train_dataloader, val_dataloaders)
        return loader
class RandomOneShot(OneShotStrategy):
    # Reuse the module's docstring so strategy and module docs stay in sync.
    __doc__ = RandomSamplingModule._random_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        # Every keyword argument is handed to ``RandomSamplingModule`` unchanged.
        super().__init__(RandomSamplingModule, **kwargs)

    def _get_dataloader(self, train_dataloader, val_dataloaders):
        # Random sampling uses the stock lightning loop: hand both loaders
        # back untouched so ``fit`` receives them separately.
        return (train_dataloader, val_dataloaders)
...@@ -7,3 +7,4 @@ from .evolution import RegularizedEvolution ...@@ -7,3 +7,4 @@ from .evolution import RegularizedEvolution
from .tpe_strategy import TPEStrategy from .tpe_strategy import TPEStrategy
from .local_debug_strategy import _LocalDebugStrategy from .local_debug_strategy import _LocalDebugStrategy
from .rl import PolicyBasedRL from .rl import PolicyBasedRL
from .oneshot import DARTS, Proxyless, SNAS, ENAS, RandomOneShot
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import abc import abc
from typing import List from typing import List, Any
from ..graph import Model from ..graph import Model
from ..mutator import Mutator from ..mutator import Mutator
...@@ -13,3 +13,6 @@ class BaseStrategy(abc.ABC): ...@@ -13,3 +13,6 @@ class BaseStrategy(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None: def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
pass pass
def export_top_models(self) -> List[Any]:
    """Return the best model(s) found; strategies supporting export override this."""
    raise NotImplementedError('"export_top_models" is not implemented.')
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base import BaseStrategy
# Re-export the one-shot strategies when their pytorch implementation is
# importable; otherwise install placeholders that re-raise the ImportError
# only when the strategy is actually used.
try:
    from nni.retiarii.oneshot.pytorch.strategy import (  # pylint: disable=unused-import
        DARTS, SNAS, Proxyless, ENAS, RandomOneShot
    )
except ImportError as import_err:
    # Keep the original error so the placeholder can surface it later.
    _import_err = import_err

    class ImportFailedStrategy(BaseStrategy):
        # Stand-in used when the one-shot implementation cannot be imported
        # (e.g. pytorch is missing); fails loudly on first use.
        def run(self, base_model, applied_mutators):
            raise _import_err

    # Assign via globals() rather than plain assignment so static type
    # checkers keep pointing at the real classes from the ``try`` branch.
    globals()['DARTS'] = ImportFailedStrategy
    globals()['SNAS'] = ImportFailedStrategy
    globals()['Proxyless'] = ImportFailedStrategy
    globals()['ENAS'] = ImportFailedStrategy
    globals()['RandomOneShot'] = ImportFailedStrategy
...@@ -8,12 +8,10 @@ from torchvision import transforms ...@@ -8,12 +8,10 @@ from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
from torch.utils.data.sampler import RandomSampler from torch.utils.data.sampler import RandomSampler
from nni.retiarii import strategy, model_wrapper
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from nni.retiarii.oneshot.pytorch import (ConcatenateTrainValDataLoader,
DartsModule, EnasModule, SNASModule,
InterleavedTrainValDataLoader,
ProxylessModule, RandomSampleModule)
class DepthwiseSeparableConv(nn.Module): class DepthwiseSeparableConv(nn.Module):
...@@ -26,6 +24,7 @@ class DepthwiseSeparableConv(nn.Module): ...@@ -26,6 +24,7 @@ class DepthwiseSeparableConv(nn.Module):
return self.pointwise(self.depthwise(x)) return self.pointwise(self.depthwise(x))
@model_wrapper
class Net(pl.LightningModule): class Net(pl.LightningModule):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -68,7 +67,6 @@ class Net(pl.LightningModule): ...@@ -68,7 +67,6 @@ class Net(pl.LightningModule):
return output return output
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def prepare_model_data(): def prepare_model_data():
base_model = Net() base_model = Net()
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
...@@ -86,53 +84,42 @@ def prepare_model_data(): ...@@ -86,53 +84,42 @@ def prepare_model_data():
return base_model, train_loader, valid_loader, trainer_kwargs return base_model, train_loader, valid_loader, trainer_kwargs
def _test_strategy(strategy_):
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
experiment = RetiariiExperiment(base_model, cls, strategy=strategy_)
config = RetiariiExeConfig()
config.execution_engine = 'oneshot'
experiment.run(config)
assert isinstance(experiment.export_top_models()[0], dict)
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def test_darts(): def test_darts():
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data() _test_strategy(strategy.DARTS())
cls = Classification(train_dataloader=train_loader, val_dataloaders = valid_loader, **trainer_kwargs)
cls.module.set_model(base_model)
darts_model = DartsModule(cls.module)
para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
cls.trainer.fit(darts_model, para_loader)
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def test_proxyless(): def test_proxyless():
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data() _test_strategy(strategy.Proxyless())
cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
cls.module.set_model(base_model)
proxyless_model = ProxylessModule(cls.module)
para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
cls.trainer.fit(proxyless_model, para_loader)
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def test_enas(): def test_enas():
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data() _test_strategy(strategy.ENAS())
cls = Classification(train_dataloader = train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
cls.module.set_model(base_model)
enas_model = EnasModule(cls.module)
concat_loader = ConcatenateTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
cls.trainer.fit(enas_model, concat_loader)
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def test_random(): def test_random():
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data() _test_strategy(strategy.RandomOneShot())
cls = Classification(train_dataloader = train_loader, val_dataloaders=valid_loader , **trainer_kwargs)
cls.module.set_model(base_model)
random_model = RandomSampleModule(cls.module)
cls.trainer.fit(random_model, cls.train_dataloader, cls.val_dataloaders)
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs')
def test_snas(): def test_snas():
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data() _test_strategy(strategy.SNAS())
cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
cls.module.set_model(base_model)
proxyless_model = SNASModule(cls.module, 1, use_temp_anneal=True)
para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
cls.trainer.fit(proxyless_model, para_loader)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment