"vscode:/vscode.git/clone" did not exist on "d9f71ab3c3cc162226ec1c9945fef1a5faf4c512"
Unverified Commit ebd56271 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Doc] NAS improvements (#4613)

parent b52f7756
...@@ -3,5 +3,8 @@ build/ ...@@ -3,5 +3,8 @@ build/
# legacy build # legacy build
_build/ _build/
# ignored copied rst in tutorials
**/tutorials/cp_*.rst
# auto-generated reference table # auto-generated reference table
_modules/ _modules/
"""Creating hard links for tutorials in each individual topics."""
import os
import re
# Mapping from a source tutorial rst to the hard-copy location referenced by
# other toctrees. Extend this dict to surface a tutorial in another topic.
cp_list = {
    'tutorials/hello_nas.rst': 'tutorials/cp_hello_nas_quickstart.rst'
}

HEADER = """.. THIS FILE IS A COPY OF {} WITH MODIFICATIONS.
.. TO MAKE ONE TUTORIAL APPEAR IN MULTIPLE PLACES.
"""


def copy_tutorials(app):
    """Copy every tutorial listed in ``cp_list`` to its duplicate location.

    Each rst label in the copy is prefixed with ``tutorial_cp_`` so Sphinx
    does not report duplicate labels, and all references to those labels
    inside the copy are rewritten to match.

    Parameters
    ----------
    app
        The Sphinx application; only ``app.srcdir`` is read.
    """
    # TODO: use sphinx logger
    print('[tutorial links] copy tutorials...')
    for src, tar in cp_list.items():
        target_path = os.path.join(app.srcdir, tar)
        # Context managers guarantee the file handles are closed (the
        # original left both the read and write handles dangling).
        with open(os.path.join(app.srcdir, src)) as f:
            content = f.read()
        # Prepend a header marking this file as an auto-generated copy.
        content = HEADER.format(src) + content
        # Add a prefix to labels to avoid duplicates.
        label_map = {}
        for prefix, label_name in re.findall(r'(\.\.\s*_)(.*?)\:', content):
            label_map[label_name] = 'tutorial_cp_' + label_name
            # anchor, e.g. ".. _label:"
            content = content.replace(prefix + label_name + ':', prefix + label_map[label_name] + ':')
            # :ref:`xxx` — keep the closing backtick; the original dropped it,
            # breaking every plain ref role in the generated copy.
            content = content.replace(f':ref:`{label_name}`', f':ref:`{label_map[label_name]}`')
            # :ref:`yyy <xxx>` — escape the label in case it contains regex
            # metacharacters.
            content = re.sub(r"(\:ref\:`.*?\<)" + re.escape(label_name) + r"(\>`)",
                             r'\1' + label_map[label_name] + r'\2', content)
        with open(target_path, 'w') as f:
            f.write(content)
def setup(app):
    """Sphinx extension entry point: register the tutorial-copy hook.

    ``copy_tutorials`` is connected to the ``builder-inited`` event, so the
    hard-linked tutorial copies are (re)generated at the start of each build.
    """
    # See life-cycle of sphinx app here:
    # https://www.sphinx-doc.org/en/master/extdev/appapi.html#sphinx-core-events
    app.connect('builder-inited', copy_tutorials)
...@@ -57,6 +57,7 @@ extensions = [ ...@@ -57,6 +57,7 @@ extensions = [
'IPython.sphinxext.ipython_console_highlighting', 'IPython.sphinxext.ipython_console_highlighting',
# Custom extensions in extension/ folder. # Custom extensions in extension/ folder.
'tutorial_links', # this has to be after sphinx-gallery
'inplace_translation', 'inplace_translation',
'cardlinkitem', 'cardlinkitem',
'patch_docutils', 'patch_docutils',
......
...@@ -42,6 +42,7 @@ ValueChoice ...@@ -42,6 +42,7 @@ ValueChoice
.. autoclass:: nni.retiarii.nn.pytorch.ValueChoice .. autoclass:: nni.retiarii.nn.pytorch.ValueChoice
:members: :members:
:inherited-members: Module
.. _nas-repeat: .. _nas-repeat:
......
...@@ -52,8 +52,8 @@ NNI provides some commonly used model evaluators for users' convenience. These e ...@@ -52,8 +52,8 @@ NNI provides some commonly used model evaluators for users' convenience. These e
We recommend to read the `serialization tutorial <./Serialization.rst>`__ before using these evaluators. A few notes to summarize the tutorial: We recommend to read the `serialization tutorial <./Serialization.rst>`__ before using these evaluators. A few notes to summarize the tutorial:
1. ``pl.DataLoader`` should be used in place of ``torch.utils.data.DataLoader``. 1. :class:`nni.retiarii.evaluator.pytorch.DataLoader` should be used in place of ``torch.utils.data.DataLoader``.
2. The datasets used in data-loader should be decorated with ``nni.trace`` recursively. 2. The datasets used in data-loader should be decorated with :meth:`nni.trace` recursively.
For example, For example,
...@@ -76,7 +76,7 @@ Customize Evaluator with PyTorch-Lightning ...@@ -76,7 +76,7 @@ Customize Evaluator with PyTorch-Lightning
Another approach is to write training code in PyTorch-Lightning style, that is, to write a LightningModule that defines all elements needed for training (e.g., loss function, optimizer) and to define a trainer that takes (optional) dataloaders to execute the training. Before that, please read the `document of PyTorch-lightning <https://pytorch-lightning.readthedocs.io/>`__ to learn the basic concepts and components provided by PyTorch-lightning. Another approach is to write training code in PyTorch-Lightning style, that is, to write a LightningModule that defines all elements needed for training (e.g., loss function, optimizer) and to define a trainer that takes (optional) dataloaders to execute the training. Before that, please read the `document of PyTorch-lightning <https://pytorch-lightning.readthedocs.io/>`__ to learn the basic concepts and components provided by PyTorch-lightning.
In practice, writing a new training module in Retiarii should inherit ``nni.retiarii.evaluator.pytorch.lightning.LightningModule``, which has a ``set_model`` that will be called after ``__init__`` to save the candidate model (generated by strategy) as ``self.model``. The rest of the process (like ``training_step``) should be the same as writing any other lightning module. Evaluators should also communicate with strategies via two API calls (``nni.report_intermediate_result`` for periodical metrics and ``nni.report_final_result`` for final metrics), added in ``on_validation_epoch_end`` and ``teardown`` respectively. In practice, writing a new training module in Retiarii should inherit :class:`nni.retiarii.evaluator.pytorch.LightningModule`, which has a ``set_model`` that will be called after ``__init__`` to save the candidate model (generated by strategy) as ``self.model``. The rest of the process (like ``training_step``) should be the same as writing any other lightning module. Evaluators should also communicate with strategies via two API calls (:meth:`nni.report_intermediate_result` for periodical metrics and :meth:`nni.report_final_result` for final metrics), added in ``on_validation_epoch_end`` and ``teardown`` respectively.
An example is as follows: An example is as follows:
...@@ -129,7 +129,7 @@ An example is as follows: ...@@ -129,7 +129,7 @@ An example is as follows:
if stage == 'fit': if stage == 'fit':
nni.report_final_result(self.trainer.callback_metrics['val_loss'].item()) nni.report_final_result(self.trainer.callback_metrics['val_loss'].item())
Then, users need to wrap everything (including LightningModule, trainer and dataloaders) into a ``Lightning`` object, and pass this object into a Retiarii experiment. Then, users need to wrap everything (including LightningModule, trainer and dataloaders) into a :class:`nni.retiarii.evaluator.pytorch.Lightning` object, and pass this object into a Retiarii experiment.
.. code-block:: python .. code-block:: python
......
...@@ -5,7 +5,7 @@ Retiarii for Neural Architecture Search ...@@ -5,7 +5,7 @@ Retiarii for Neural Architecture Search
:hidden: :hidden:
:titlesonly: :titlesonly:
Quick Start <../tutorials/hello_nas> Quick Start <../tutorials/cp_hello_nas_quickstart>
construct_space construct_space
exploration_strategy exploration_strategy
evaluator evaluator
...@@ -60,25 +60,25 @@ The following APIs are provided to ease the engineering effort of writing a new ...@@ -60,25 +60,25 @@ The following APIs are provided to ease the engineering effort of writing a new
- Category - Category
- Brief Description - Brief Description
* - :ref:`nas-layer-choice` * - :ref:`nas-layer-choice`
- :ref:`Multi-trial <multi-trial-nas>` - :ref:`Mutation Primitives <mutation-primitives>`
- Select from some PyTorch modules - Select from some PyTorch modules
* - :ref:`nas-input-choice` * - :ref:`nas-input-choice`
- :ref:`Multi-trial <multi-trial-nas>` - :ref:`Mutation Primitives <mutation-primitives>`
- Select from some inputs (tensors) - Select from some inputs (tensors)
* - :ref:`nas-value-choice` * - :ref:`nas-value-choice`
- :ref:`Multi-trial <multi-trial-nas>` - :ref:`Mutation Primitives <mutation-primitives>`
- Select from some candidate values - Select from some candidate values
* - :ref:`nas-repeat` * - :ref:`nas-repeat`
- :ref:`Multi-trial <multi-trial-nas>` - :ref:`Mutation Primitives <mutation-primitives>`
- Repeat a block by a variable number of times - Repeat a block by a variable number of times
* - :ref:`nas-cell` * - :ref:`nas-cell`
- :ref:`Multi-trial <multi-trial-nas>` - :ref:`Mutation Primitives <mutation-primitives>`
- Cell structure popularly used in literature - Cell structure popularly used in literature
* - :ref:`nas-cell-101` * - :ref:`nas-cell-101`
- :ref:`Multi-trial <multi-trial-nas>` - :ref:`Mutation Primitives <mutation-primitives>`
- Cell structure (variant) proposed by NAS-Bench-101 - Cell structure (variant) proposed by NAS-Bench-101
* - :ref:`nas-cell-201` * - :ref:`nas-cell-201`
- :ref:`Multi-trial <multi-trial-nas>` - :ref:`Mutation Primitives <mutation-primitives>`
- Cell structure (variant) proposed by NAS-Bench-201 - Cell structure (variant) proposed by NAS-Bench-201
* - :ref:`nas-autoactivation` * - :ref:`nas-autoactivation`
- :ref:`Hyper-modules <hyper-modules>` - :ref:`Hyper-modules <hyper-modules>`
......
...@@ -32,6 +32,11 @@ nni.retiarii.evaluator ...@@ -32,6 +32,11 @@ nni.retiarii.evaluator
.. automodule:: nni.retiarii.evaluator.pytorch .. automodule:: nni.retiarii.evaluator.pytorch
:imported-members: :imported-members:
:members: :members:
:exclude-members: Trainer, DataLoader
.. autoclass:: nni.retiarii.evaluator.pytorch.Trainer
.. autoclass:: nni.retiarii.evaluator.pytorch.DataLoader
nni.retiarii.execution nni.retiarii.execution
---------------------- ----------------------
......
...@@ -52,3 +52,13 @@ nav.md-tabs .md-tabs__item:not(:last-child) .md-tabs__link:after { ...@@ -52,3 +52,13 @@ nav.md-tabs .md-tabs__item:not(:last-child) .md-tabs__link:after {
.citation dt { .citation dt {
padding-right: 1em; padding-right: 1em;
} }
/* fixes reference overlapping issue */
/* This is originally defined to be negative in application_fixes.css */
/* They did that to ensure the header doesn't disappear in jump links */
/* We did this by using scroll-margin-top instead */
dt:target {
margin-top: 0.15rem !important;
padding-top: 0 !important;
scroll-margin-top: 3.5rem;
}
...@@ -5,6 +5,7 @@ import copy ...@@ -5,6 +5,7 @@ import copy
import functools import functools
import inspect import inspect
import numbers import numbers
import os
import types import types
import warnings import warnings
from io import IOBase from io import IOBase
...@@ -235,6 +236,13 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable] ...@@ -235,6 +236,13 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
pass pass
""" """
# This is an internal flag to control the behavior of trace.
# Useful in doc build and tests.
# Might be changed in future.
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls_or_func
def wrap(cls_or_func): def wrap(cls_or_func):
# already annotated, do nothing # already annotated, do nothing
if getattr(cls_or_func, '_traced', False): if getattr(cls_or_func, '_traced', False):
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import os import os
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Dict, Union, Optional, List, Type from typing import Dict, Union, Optional, List, Callable
import pytorch_lightning as pl import pytorch_lightning as pl
import torch.nn as nn import torch.nn as nn
...@@ -29,11 +29,20 @@ __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classificat ...@@ -29,11 +29,20 @@ __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classificat
class LightningModule(pl.LightningModule): class LightningModule(pl.LightningModule):
""" """
Basic wrapper of generated model. Basic wrapper of generated model.
Lightning modules used in NNI should inherit this class. Lightning modules used in NNI should inherit this class.
It's a subclass of ``pytorch_lightning.LightningModule``.
See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
""" """
def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> None: def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
"""Set the inner model (architecture) to train / evaluate.
Parameters
----------
model : callable or nn.Module
Can be a callable returning nn.Module or nn.Module.
"""
if isinstance(model, nn.Module): if isinstance(model, nn.Module):
self.model = model self.model = model
else: else:
...@@ -41,9 +50,13 @@ class LightningModule(pl.LightningModule): ...@@ -41,9 +50,13 @@ class LightningModule(pl.LightningModule):
Trainer = nni.trace(pl.Trainer) Trainer = nni.trace(pl.Trainer)
Trainer.__doc__ = 'Traced version of ``pytorch_lightning.Trainer``.' Trainer.__doc__ = """
Traced version of ``pytorch_lightning.Trainer``. See https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
"""
DataLoader = nni.trace(torch_data.DataLoader) DataLoader = nni.trace(torch_data.DataLoader)
DataLoader.__doc__ = 'Traced version of ``torch.utils.data.DataLoader``.' DataLoader.__doc__ = """
Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
"""
@nni.trace @nni.trace
class Lightning(Evaluator): class Lightning(Evaluator):
...@@ -238,7 +251,7 @@ class _ClassificationModule(_SupervisedLearningModule): ...@@ -238,7 +251,7 @@ class _ClassificationModule(_SupervisedLearningModule):
class Classification(Lightning): class Classification(Lightning):
""" """
Trainer that is used for classification. Evaluator that is used for classification.
Parameters Parameters
---------- ----------
...@@ -291,7 +304,7 @@ class _RegressionModule(_SupervisedLearningModule): ...@@ -291,7 +304,7 @@ class _RegressionModule(_SupervisedLearningModule):
class Regression(Lightning): class Regression(Lightning):
""" """
Trainer that is used for regression. Evaluator that is used for regression.
Parameters Parameters
---------- ----------
......
...@@ -567,7 +567,8 @@ class ValueChoiceX(Translatable): ...@@ -567,7 +567,8 @@ class ValueChoiceX(Translatable):
def __index__(self) -> NoReturn: def __index__(self) -> NoReturn:
# https://docs.python.org/3/reference/datamodel.html#object.__index__ # https://docs.python.org/3/reference/datamodel.html#object.__index__
raise RuntimeError("`__index__` is not allowed on ValueChoice, which means you can't " raise RuntimeError("`__index__` is not allowed on ValueChoice, which means you can't "
"use int(), float(), complex(), range() on a ValueChoice.") "use int(), float(), complex(), range() on a ValueChoice. "
"To cast the type of ValueChoice, please try `ValueChoice.to_int()` or `ValueChoice.to_float()`.")
def __bool__(self) -> NoReturn: def __bool__(self) -> NoReturn:
raise RuntimeError('Cannot use bool() on ValueChoice. That means, using ValueChoice in a if-clause is illegal. ' raise RuntimeError('Cannot use bool() on ValueChoice. That means, using ValueChoice in a if-clause is illegal. '
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import inspect import inspect
import os
import warnings import warnings
from typing import Any, TypeVar, Union from typing import Any, TypeVar, Union
...@@ -64,6 +65,12 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]: ...@@ -64,6 +65,12 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
class PrimitiveOp(nn.Module): class PrimitiveOp(nn.Module):
... ...
""" """
# Internal flag. See nni.trace
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls
_check_wrapped(cls) _check_wrapped(cls)
import torch.nn as nn import torch.nn as nn
...@@ -103,6 +110,12 @@ def model_wrapper(cls: T) -> Union[T, Traceable]: ...@@ -103,6 +110,12 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed. Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed.
But in future, we might enforce ``@model_wrapper`` to be required for base model. But in future, we might enforce ``@model_wrapper`` to be required for base model.
""" """
# Internal flag. See nni.trace
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls
_check_wrapped(cls) _check_wrapped(cls)
import torch.nn as nn import torch.nn as nn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment