Unverified Commit 89786596 authored by Yuge Zhang, committed by GitHub

NAS documents general improvements (v2.8) (#4942)

parent b99e2683
...@@ -8,7 +8,7 @@ A model evaluator is for training and validating each generated model. They are
Customize Evaluator with Any Function
-------------------------------------
The simplest way to customize a new evaluator is with :class:`~nni.retiarii.evaluator.FunctionalEvaluator`, which is very easy when training code is already available. Users only need to write a fit function that wraps everything, which usually includes training, validating and testing of a single model. This function takes one positional argument (``model_cls``) and possible keyword arguments. The keyword arguments (other than ``model_cls``) are fed to :class:`~nni.retiarii.evaluator.FunctionalEvaluator` as its initialization parameters (note that they will be :doc:`serialized <./serialization>`). In this way, users get everything under their control, but expose less information to the framework, and as a result, further optimizations like :ref:`CGO <cgo-execution-engine>` might not be feasible. An example is as follows:
.. code-block:: python
...@@ -42,6 +42,41 @@ The simplest way to customize a new evaluator is with :class:`FunctionalEvaluato
If the conversion is successful, the model can be visualized with the powerful tool `Netron <https://netron.app/>`__.
Use Evaluators to Train and Evaluate Models
-------------------------------------------
Users can use evaluators to train or evaluate a single, concrete architecture. This is very useful when:
* Debugging your evaluator against a baseline model.
* Fully training, validating and testing your model after the search process is complete.
The usage is shown below:
.. code-block:: python
    # Class definition of a single model, for example, ResNet.
    class SingleModel(nn.Module):
        def __init__(self):  # Can't have extra init parameters here.
            ...

    # Use a callable returning a model
    evaluator.evaluate(SingleModel)
    # Or initialize the model beforehand
    evaluator.evaluate(SingleModel())
The underlying implementation of :meth:`~nni.retiarii.Evaluator.evaluate` depends on the concrete evaluator you use.
For example, if :class:`~nni.retiarii.evaluator.FunctionalEvaluator` is used, it will run your customized fit function.
If Lightning evaluators like :class:`nni.retiarii.evaluator.pytorch.Classification` are used, it will invoke ``trainer.fit()`` of Lightning.
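
For instance, with a :class:`~nni.retiarii.evaluator.FunctionalEvaluator`, calling ``evaluate`` ends up running the wrapped fit function. A minimal sketch (``fit`` and ``learning_rate`` are illustrative names, not part of the API):

.. code-block:: python

    from nni.retiarii.evaluator import FunctionalEvaluator

    def fit(model_cls, learning_rate):
        # Illustrative fit function: ``model_cls`` is whatever was passed to ``evaluate``.
        model = model_cls()
        # ... train and validate ``model`` with the given learning rate ...

    evaluator = FunctionalEvaluator(fit, learning_rate=0.1)
    evaluator.evaluate(SingleModel)  # runs fit(SingleModel, learning_rate=0.1)
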
To evaluate an architecture that is exported from an experiment (i.e., from :meth:`~nni.retiarii.experiment.pytorch.RetiariiExperiment.export_top_models`), use :func:`nni.retiarii.fixed_arch` to instantiate the exported model::
    with fixed_arch(exported_model):
        model = ModelSpace()

    # Then use evaluator.evaluate
    evaluator.evaluate(model)
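
For context, the ``exported_model`` above typically comes from a finished experiment, for example (a sketch; ``experiment`` is assumed to be a :class:`~nni.retiarii.experiment.pytorch.RetiariiExperiment` that has already run)::

    # Take the best architecture found by the search.
    exported_model = experiment.export_top_models(top_k=1)[0]
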
.. tip:: The trained checkpoint of a super-net produced by one-shot strategies can be ported to the concrete chosen architecture, thanks to :func:`nni.retiarii.utils.original_state_dict_hooks`. This is helpful when implementing recent multi-stage NAS algorithms like `SPOS <https://arxiv.org/abs/1904.00420>`__.
.. _lightning-evaluator:
Evaluators with PyTorch-Lightning
...@@ -134,6 +169,10 @@ An example is as follows:
if stage == 'fit':
    nni.report_final_result(self.trainer.callback_metrics['val_loss'].item())
.. note::
    If you are trying to use your customized evaluator with a one-shot strategy, bear in mind that the methods you define will be reassembled into another LightningModule, which might result in extra constraints when writing the LightningModule. For example, your validation step could appear elsewhere (e.g., in ``training_step``). This prohibits you from returning arbitrary objects from ``validation_step``.
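
For example, a ``validation_step`` that stays compatible with this re-assembly logs scalars instead of returning custom objects. A sketch (the candidate model is assumed to be accessible as ``self.model``; adjust to however your module stores it):

.. code-block:: python

    import torch.nn.functional as F
    from nni.retiarii.evaluator.pytorch import LightningModule

    class MyLightningModule(LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self.model(x), y)
            # Report scalars via self.log rather than returning arbitrary objects,
            # so this step remains valid after being re-assembled by a one-shot strategy.
            self.log('val_loss', loss)
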
Then, users need to wrap everything (including LightningModule, trainer and dataloaders) into a :class:`nni.retiarii.evaluator.pytorch.Lightning` object, and pass this object into a Retiarii experiment.
.. code-block:: python
...
...@@ -75,7 +75,15 @@ Starting from v2.8, the usage of one-shot strategies are much alike to multi-tri
import nni.retiarii.strategy as strategy
import nni.retiarii.evaluator.pytorch.lightning as pl
evaluator = pl.Classification(
    # Need to use `pl.DataLoader` instead of `torch.utils.data.DataLoader` here,
    # or use `nni.trace` to wrap `torch.utils.data.DataLoader`.
    train_dataloaders=pl.DataLoader(train_dataset, batch_size=100),
    val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
    # Other keyword arguments passed to pytorch_lightning.Trainer.
    max_epochs=10,
    gpus=1,
)
exploration_strategy = strategy.DARTS()
exp_config.execution_engine = 'oneshot'
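
As the comment in the snippet above notes, ``nni.trace`` can wrap ``torch.utils.data.DataLoader`` as an alternative to ``pl.DataLoader``. A sketch under that assumption:

.. code-block:: python

    import nni
    from torch.utils.data import DataLoader

    # Wrapping the DataLoader class with nni.trace records its init arguments,
    # which is what the evaluator needs to (de)serialize the dataloaders.
    train_loader = nni.trace(DataLoader)(train_dataset, batch_size=100)
    val_loader = nni.trace(DataLoader)(test_dataset, batch_size=100)
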
...
...@@ -29,9 +29,14 @@ Utilities
Customization
-------------
.. autoclass:: nni.retiarii.Evaluator
   :members:

.. autoclass:: nni.retiarii.evaluator.pytorch.Lightning
   :members:

.. autoclass:: nni.retiarii.evaluator.pytorch.LightningModule
   :members:
Cross-graph Optimization (experimental)
---------------------------------------
...
...@@ -79,6 +79,11 @@ class Lightning(Evaluator):
in trainer. Two hooks are added at the end of validation epoch and the end of ``fit``, respectively. The metric name
and type depend on the specific task.
.. warning::
    Lightning evaluators are stateful. If you use a previous Lightning evaluator again,
    please note that the inner ``lightning_module`` and ``trainer`` will be reused.
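
A hedged illustration (helper name and datasets are hypothetical): construct a fresh evaluator for each experiment instead of re-running an old one.

.. code-block:: python

    import nni.retiarii.evaluator.pytorch.lightning as pl

    def make_evaluator():
        # Build a brand-new evaluator so that no lightning_module / trainer
        # state leaks from a previous run.
        return pl.Classification(
            train_dataloaders=pl.DataLoader(train_dataset, batch_size=100),
            val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
            max_epochs=10,
        )

    evaluator_for_search = make_evaluator()
    evaluator_for_final_training = make_evaluator()  # do not reuse the one above
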
Parameters
----------
lightning_module
...
...@@ -324,6 +324,9 @@ class RetiariiExperiment(Experiment):
For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
available for customization.
The concrete behavior of export depends on each strategy.
See the documentation of each strategy for detailed specifications.
Parameters
----------
top_k : int
...
...@@ -19,7 +19,7 @@ if TYPE_CHECKING:
from .operation import Cell, Operation, _IOPseudoOperation
from .utils import uid
__all__ = ['Evaluator', 'Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData']
MetricData = Any
...@@ -43,6 +43,13 @@ class Evaluator(abc.ABC):
For example, functional evaluator might directly import the function and call the function.
"""
def evaluate(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
"""To run evaluation of a model. The model could be either a concrete model or a callable returning a model.
The concrete implementation of evaluate depends on the implementation of ``_execute()`` in sub-class.
"""
return self._execute(model_cls)
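# Sketch (hypothetical, for illustration only): a concrete subclass supplies ``_execute``
# (plus any other abstract hooks it needs). A functional-style evaluator, for instance,
# could simply dispatch to a user-supplied function:
#
#     def _execute(self, model_cls):
#         return self.user_function(model_cls, **self.user_kwargs)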
def __repr__(self):
items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
return f'{self.__class__.__name__}({items})'
...@@ -355,6 +362,7 @@ class Graph:
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str,
                        parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...
...
...@@ -157,7 +157,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
Mutation hooks are callables that take a module and return a
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
They are invoked in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodule.
For each submodule, the hooks in the list are invoked sequentially;
the later hooks can see the results from previous hooks.
The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,
...@@ -189,7 +189,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
it means the hook suggests to
keep the module unchanged, and nothing will happen.
An example of a mutation hook is given in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.no_default_hook`.
However, it's recommended to implement mutation hooks by deriving
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
and adding its classmethod ``mutate`` to this list.
...
...@@ -29,7 +29,8 @@ class DartsLightningModule(BaseOneShotLightningModule):
Phase 1 is the architecture step, in which model parameters are frozen and the architecture parameters are trained.
Phase 2 is the model step, in which architecture parameters are frozen and model parameters are trained.
The current implementation corresponds to DARTS (1st order) in the paper.
Second order (unrolled 2nd-order derivatives) is not supported yet.
.. versionadded:: 2.8
...@@ -186,8 +187,9 @@ class GumbelDartsLightningModule(DartsLightningModule):
See `FBNet <https://arxiv.org/abs/1812.03443>`__ and `SNAS <https://arxiv.org/abs/1812.09926>`__.
This is a DARTS-based method that uses gumbel-softmax to simulate one-hot distribution.
Essentially, it tries to mimic the behavior of sampling one path on forward by gradually
cooling down the temperature, aiming to bridge the gap between differentiable architecture weights and
the discretization of architectures.
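
For reference, the Gumbel-Softmax relaxation referred to here is commonly written as

.. math::

    y_i = \frac{\exp((\log \alpha_i + g_i) / \tau)}{\sum_j \exp((\log \alpha_j + g_j) / \tau)},
    \qquad g_i \sim \mathrm{Gumbel}(0, 1)

where :math:`\tau` is the temperature that is gradually cooled down during training (standard formulation; the exact schedule is implementation-specific).
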
.. versionadded:: 2.8
...
...@@ -68,7 +68,9 @@ class ReinforceController(nn.Module):
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
skip_target : float
Target probability that skip-connect (chosen by InputChoice) will appear.
If the chosen inputs deviate from this target, a sample skip penalty,
computed as a KL divergence, will be added.
temperature : float
Temperature constant that divides the logits.
entropy_reduction : str
...
...@@ -4,6 +4,7 @@
"""Experimental version of sampling-based one-shot implementation."""
from __future__ import annotations
import warnings
from typing import Any
import pytorch_lightning as pl
...@@ -76,6 +77,18 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
self.resample()
return self.model.training_step(batch, batch_idx)
def export(self) -> dict[str, Any]:
"""
Export of the random one-shot strategy. It returns an arbitrary architecture.
"""
warnings.warn(
'Direct export from RandomOneShot returns an arbitrary architecture. '
'Sampling the best architecture from this trained supernet is another search process. '
'Users need to do another search based on the checkpoint of the one-shot strategy.',
UserWarning
)
return super().export()
class EnasLightningModule(RandomSamplingLightningModule):
_enas_note = """
...@@ -86,8 +99,10 @@ class EnasLightningModule(RandomSamplingLightningModule):
- Firstly, training model parameters.
- Secondly, training ENAS RL agent. The agent will produce a sample of model architecture to get the best reward.
.. note::

    ENAS requires the evaluator to report metrics via ``self.log`` in its ``validation_step``.
    See explanation of ``reward_metric_name`` for details.
The supported mutation primitives of ENAS are:
...@@ -105,22 +120,24 @@ class EnasLightningModule(RandomSamplingLightningModule):
{{module_params}}
{base_params}
ctrl_kwargs : dict
Optional kwargs that will be passed to :class:`~nni.retiarii.oneshot.pytorch.enas.ReinforceController`.
entropy_weight : float
Weight of sample entropy loss in RL.
skip_weight : float
Weight of skip penalty loss. See :class:`~nni.retiarii.oneshot.pytorch.enas.ReinforceController` for details.
baseline_decay : float
Decay factor of the reward baseline, which is used to normalize the reward in RL.
At each step, the new reward baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
ctrl_steps_aggregate : int
Number of steps for which the gradients will be accumulated,
before updating the weights of the RL controller.
ctrl_grad_clip : float
Gradient clipping value of the controller.
reward_metric_name : str or None
The name of the metric which is treated as reward.
This will not be effective when there's only one metric returned from the evaluator.
If there are multiple metrics, it will look for the metric with key name ``default`` by default.
If ``reward_metric_name`` is specified, it will look for the metric with that name instead.
Otherwise, it raises an exception indicating that multiple metrics are found.
""".format(
base_params=BaseOneShotLightningModule._mutation_hooks_note,
...@@ -76,6 +76,7 @@ class OneShotStrategy(BaseStrategy):
evaluator.trainer.fit(self.model, train_loader, val_loader)
def export_top_models(self, top_k: int = 1) -> list[Any]:
"""The behavior of export top models in strategy depends on the implementation of inner one-shot module."""
if self.model is None:
raise RuntimeError('One-shot strategy needs to be run before export.')
if top_k != 1:
...