Unverified Commit ebd56271 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Doc] NAS improvements (#4613)

parent b52f7756
......@@ -3,5 +3,8 @@ build/
# legacy build
_build/
# ignored copied rst in tutorials
**/tutorials/cp_*.rst
# auto-generated reference table
_modules/
"""Creating hard links for tutorials in each individual topics."""
import os
import re
cp_list = {
'tutorials/hello_nas.rst': 'tutorials/cp_hello_nas_quickstart.rst'
}
HEADER = """.. THIS FILE IS A COPY OF {} WITH MODIFICATIONS.
.. TO MAKE ONE TUTORIAL APPEAR IN MULTIPLE PLACES.
"""
def copy_tutorials(app):
# TODO: use sphinx logger
print('[tutorial links] copy tutorials...')
for src, tar in cp_list.items():
target_path = os.path.join(app.srcdir, tar)
content = open(os.path.join(app.srcdir, src)).read()
# Add a header
content = HEADER.format(src) + content
# Add a prefix to labels to avoid duplicates.
label_map = {}
for prefix, label_name in list(re.findall(r'(\.\.\s*_)(.*?)\:', content)):
label_map[label_name] = 'tutorial_cp_' + label_name
# anchor
content = content.replace(prefix + label_name + ':', prefix + label_map[label_name] + ':')
# :ref:`xxx`
content = content.replace(f':ref:`{label_name}`', f':ref:`{label_map[label_name]}')
# :ref:`yyy <xxx>`
content = re.sub(r"(\:ref\:`.*?\<)" + label_name + r"(\>`)", r'\1' + label_map[label_name] + r'\2', content)
open(target_path, 'w').write(content)
def setup(app):
# See life-cycle of sphinx app here:
# https://www.sphinx-doc.org/en/master/extdev/appapi.html#sphinx-core-events
app.connect('builder-inited', copy_tutorials)
......@@ -57,6 +57,7 @@ extensions = [
'IPython.sphinxext.ipython_console_highlighting',
# Custom extensions in extension/ folder.
'tutorial_links', # this has to be after sphinx-gallery
'inplace_translation',
'cardlinkitem',
'patch_docutils',
......
......@@ -42,6 +42,7 @@ ValueChoice
.. autoclass:: nni.retiarii.nn.pytorch.ValueChoice
:members:
:inherited-members: Module
.. _nas-repeat:
......
......@@ -52,8 +52,8 @@ NNI provides some commonly used model evaluators for users' convenience. These e
We recommend to read the `serialization tutorial <./Serialization.rst>`__ before using these evaluators. A few notes to summarize the tutorial:
1. ``pl.DataLoader`` should be used in place of ``torch.utils.data.DataLoader``.
2. The datasets used in data-loader should be decorated with ``nni.trace`` recursively.
1. :class:`nni.retarii.evaluator.pytorch.DataLoader`` should be used in place of ``torch.utils.data.DataLoader``.
2. The datasets used in data-loader should be decorated with :meth:`nni.trace` recursively.
For example,
......@@ -76,7 +76,7 @@ Customize Evaluator with PyTorch-Lightning
Another approach is to write training code in PyTorch-Lightning style, that is, to write a LightningModule that defines all elements needed for training (e.g., loss function, optimizer) and to define a trainer that takes (optional) dataloaders to execute the training. Before that, please read the `document of PyTorch-lightning <https://pytorch-lightning.readthedocs.io/>`__ to learn the basic concepts and components provided by PyTorch-lightning.
In practice, writing a new training module in Retiarii should inherit ``nni.retiarii.evaluator.pytorch.lightning.LightningModule``, which has a ``set_model`` that will be called after ``__init__`` to save the candidate model (generated by strategy) as ``self.model``. The rest of the process (like ``training_step``) should be the same as writing any other lightning module. Evaluators should also communicate with strategies via two API calls (``nni.report_intermediate_result`` for periodical metrics and ``nni.report_final_result`` for final metrics), added in ``on_validation_epoch_end`` and ``teardown`` respectively.
In practice, writing a new training module in Retiarii should inherit :class:`nni.retiarii.evaluator.pytorch.LightningModule`, which has a ``set_model`` that will be called after ``__init__`` to save the candidate model (generated by strategy) as ``self.model``. The rest of the process (like ``training_step``) should be the same as writing any other lightning module. Evaluators should also communicate with strategies via two API calls (:meth:`nni.report_intermediate_result` for periodical metrics and :meth:`nni.report_final_result` for final metrics), added in ``on_validation_epoch_end`` and ``teardown`` respectively.
An example is as follows:
......@@ -129,7 +129,7 @@ An example is as follows:
if stage == 'fit':
nni.report_final_result(self.trainer.callback_metrics['val_loss'].item())
Then, users need to wrap everything (including LightningModule, trainer and dataloaders) into a ``Lightning`` object, and pass this object into a Retiarii experiment.
Then, users need to wrap everything (including LightningModule, trainer and dataloaders) into a :class:`nni.retiarii.evaluator.pytorch.Lightning` object, and pass this object into a Retiarii experiment.
.. code-block:: python
......
......@@ -5,7 +5,7 @@ Retiarii for Neural Architecture Search
:hidden:
:titlesonly:
Quick Start <../tutorials/hello_nas>
Quick Start <../tutorials/cp_hello_nas_quickstart>
construct_space
exploration_strategy
evaluator
......@@ -60,25 +60,25 @@ The following APIs are provided to ease the engineering effort of writing a new
- Category
- Brief Description
* - :ref:`nas-layer-choice`
- :ref:`Multi-trial <multi-trial-nas>`
- :ref:`Mutation Primitives <mutation-primitives>`
- Select from some PyTorch modules
* - :ref:`nas-input-choice`
- :ref:`Multi-trial <multi-trial-nas>`
- :ref:`Mutation Primitives <mutation-primitives>`
- Select from some inputs (tensors)
* - :ref:`nas-value-choice`
- :ref:`Multi-trial <multi-trial-nas>`
- :ref:`Mutation Primitives <mutation-primitives>`
- Select from some candidate values
* - :ref:`nas-repeat`
- :ref:`Multi-trial <multi-trial-nas>`
- :ref:`Mutation Primitives <mutation-primitives>`
- Repeat a block by a variable number of times
* - :ref:`nas-cell`
- :ref:`Multi-trial <multi-trial-nas>`
- :ref:`Mutation Primitives <mutation-primitives>`
- Cell structure popularly used in literature
* - :ref:`nas-cell-101`
- :ref:`Multi-trial <multi-trial-nas>`
- :ref:`Mutation Primitives <mutation-primitives>`
- Cell structure (variant) proposed by NAS-Bench-101
* - :ref:`nas-cell-201`
- :ref:`Multi-trial <multi-trial-nas>`
- :ref:`Mutation Primitives <mutation-primitives>`
- Cell structure (variant) proposed by NAS-Bench-201
* - :ref:`nas-autoactivation`
- :ref:`Hyper-modules <hyper-modules>`
......
......@@ -32,6 +32,11 @@ nni.retiarii.evaluator
.. automodule:: nni.retiarii.evaluator.pytorch
:imported-members:
:members:
:exclude-members: Trainer, DataLoader
.. autoclass:: nni.retiarii.evaluator.pytorch.Trainer
.. autoclass:: nni.retiarii.evaluator.pytorch.DataLoader
nni.retiarii.execution
----------------------
......
......@@ -52,3 +52,13 @@ nav.md-tabs .md-tabs__item:not(:last-child) .md-tabs__link:after {
.citation dt {
padding-right: 1em;
}
/* fixes reference overlapping issue */
/* This is originally defined to be negative in application_fixes.css */
/* They did that to ensure the header doesn't disappear in jump links */
/* We did this by using scroll-margin-top instead */
dt:target {
margin-top: 0.15rem !important;
padding-top: 0 !important;
scroll-margin-top: 3.5rem;
}
......@@ -5,6 +5,7 @@ import copy
import functools
import inspect
import numbers
import os
import types
import warnings
from io import IOBase
......@@ -235,6 +236,13 @@ def trace(cls_or_func: T = None, *, kw_only: bool = True) -> Union[T, Traceable]
pass
"""
# This is an internal flag to control the behavior of trace.
# Useful in doc build and tests.
# Might be changed in future.
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls_or_func
def wrap(cls_or_func):
# already annotated, do nothing
if getattr(cls_or_func, '_traced', False):
......
......@@ -4,7 +4,7 @@
import os
import warnings
from pathlib import Path
from typing import Dict, Union, Optional, List, Type
from typing import Dict, Union, Optional, List, Callable
import pytorch_lightning as pl
import torch.nn as nn
......@@ -29,11 +29,20 @@ __all__ = ['LightningModule', 'Trainer', 'DataLoader', 'Lightning', 'Classificat
class LightningModule(pl.LightningModule):
"""
Basic wrapper of generated model.
Lightning modules used in NNI should inherit this class.
It's a subclass of ``pytorch_lightning.LightningModule``.
See https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html
"""
def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> None:
def set_model(self, model: Union[Callable[[], nn.Module], nn.Module]) -> None:
"""Set the inner model (architecture) to train / evaluate.
Parameters
----------
model : callable or nn.Module
Can be a callable returning nn.Module or nn.Module.
"""
if isinstance(model, nn.Module):
self.model = model
else:
......@@ -41,9 +50,13 @@ class LightningModule(pl.LightningModule):
Trainer = nni.trace(pl.Trainer)
Trainer.__doc__ = 'Traced version of ``pytorch_lightning.Trainer``.'
Trainer.__doc__ = """
Traced version of ``pytorch_lightning.Trainer``. See https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
"""
DataLoader = nni.trace(torch_data.DataLoader)
DataLoader.__doc__ = 'Traced version of ``torch.utils.data.DataLoader``.'
DataLoader.__doc__ = """
Traced version of ``torch.utils.data.DataLoader``. See https://pytorch.org/docs/stable/data.html
"""
@nni.trace
class Lightning(Evaluator):
......@@ -238,7 +251,7 @@ class _ClassificationModule(_SupervisedLearningModule):
class Classification(Lightning):
"""
Trainer that is used for classification.
Evaluator that is used for classification.
Parameters
----------
......@@ -291,7 +304,7 @@ class _RegressionModule(_SupervisedLearningModule):
class Regression(Lightning):
"""
Trainer that is used for regression.
Evaluator that is used for regression.
Parameters
----------
......
......@@ -567,7 +567,8 @@ class ValueChoiceX(Translatable):
def __index__(self) -> NoReturn:
# https://docs.python.org/3/reference/datamodel.html#object.__index__
raise RuntimeError("`__index__` is not allowed on ValueChoice, which means you can't "
"use int(), float(), complex(), range() on a ValueChoice.")
"use int(), float(), complex(), range() on a ValueChoice. "
"To cast the type of ValueChoice, please try `ValueChoice.to_int()` or `ValueChoice.to_float()`.")
def __bool__(self) -> NoReturn:
raise RuntimeError('Cannot use bool() on ValueChoice. That means, using ValueChoice in a if-clause is illegal. '
......
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import inspect
import os
import warnings
from typing import Any, TypeVar, Union
......@@ -64,6 +65,12 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
class PrimitiveOp(nn.Module):
...
"""
# Internal flag. See nni.trace
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls
_check_wrapped(cls)
import torch.nn as nn
......@@ -103,6 +110,12 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed.
But in future, we might enforce ``@model_wrapper`` to be required for base model.
"""
# Internal flag. See nni.trace
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls
_check_wrapped(cls)
import torch.nn as nn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment