"...git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "768ab4df541275c05eec5ee5db2f89661302610d"
Unverified Commit d6a49755 authored by Yuge Zhang, committed by GitHub

NAS documentation improvements (v2.9) (#5077)

parent 808b0655
@@ -5,7 +5,6 @@ Advanced Usage
    :maxdepth: 2

    execution_engine
-   space_hub
    hardware_aware_nas
    mutator
    customize_strategy
......
@@ -58,7 +58,11 @@ One way to use the model space is to directly leverage the searched results. Not
 .. code-block:: python

+    import torch
     from nni.retiarii.hub.pytorch import MobileNetV3Space
+    from torch.utils.data import DataLoader
+    from torchvision import transforms
+    from torchvision.datasets import ImageNet

     # Load one of the searched results from MobileNetV3 search space.
     mobilenetv3 = MobileNetV3Space.load_searched_model(
@@ -67,11 +71,18 @@ One way to use the model space is to directly leverage the searched results. Not
     )

     # MobileNetV3 model can be directly evaluated on ImageNet
-    dataset = ImageNet(directory, 'val', transform=test_transform)
+    transform = transforms.Compose([
+        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    dataset = ImageNet('/path/to/your/imagenet', 'val', transform=transform)
+    dataloader = DataLoader(dataset, batch_size=64)
     mobilenetv3.eval()
     with torch.no_grad():
         correct = total = 0
-        for inputs, targets in pbar:
+        for inputs, targets in dataloader:
             logits = mobilenetv3(inputs)
             _, predict = torch.max(logits, 1)
             correct += (predict == targets).sum().item()
......
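The hunk above ends mid-loop. As a sketch, the evaluation might be concluded as follows; the ``total`` update and the final accuracy printout are assumptions and not part of this diff:

.. code-block:: python

    mobilenetv3.eval()
    with torch.no_grad():
        correct = total = 0
        for inputs, targets in dataloader:
            logits = mobilenetv3(inputs)
            _, predict = torch.max(logits, 1)
            correct += (predict == targets).sum().item()
            total += targets.size(0)  # assumed counterpart of `correct`
    print('val top-1 accuracy:', correct / total)  # assumed final report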
@@ -7,6 +7,7 @@ Neural Architecture Search
    overview
    Tutorials <tutorials>
    construct_space
+   space_hub
    exploration_strategy
    evaluator
    advanced_usage
@@ -13,12 +13,18 @@ Classification
 .. autoclass:: nni.retiarii.evaluator.pytorch.Classification
    :members:

+.. autoclass:: nni.retiarii.evaluator.pytorch.ClassificationModule
+   :members:
+
 Regression
 ----------

 .. autoclass:: nni.retiarii.evaluator.pytorch.Regression
    :members:

+.. autoclass:: nni.retiarii.evaluator.pytorch.RegressionModule
+   :members:
+
 Utilities
 ---------
......
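For context, a minimal sketch of constructing the ``Classification`` evaluator documented here; the MNIST dataset, batch sizes, and learning rate are illustrative assumptions:

.. code-block:: python

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST
    from nni.retiarii.evaluator.pytorch import Classification

    transform = transforms.ToTensor()
    train_set = MNIST('data/mnist', train=True, download=True, transform=transform)
    val_set = MNIST('data/mnist', train=False, download=True, transform=transform)

    # ClassificationModule is the underlying LightningModule; Classification wraps it
    # together with a Trainer and the dataloaders. For multi-trial experiments, NNI's
    # serializable DataLoader wrapper may be preferable to the plain torch one used here.
    evaluator = Classification(
        learning_rate=1e-3,
        train_dataloaders=DataLoader(train_set, batch_size=64, shuffle=True),
        val_dataloaders=DataLoader(val_set, batch_size=64),
        max_epochs=10,
    )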
@@ -95,6 +95,15 @@ NASNet
 .. autoclass:: nni.retiarii.hub.pytorch.nasnet.NDS
    :members:

+.. autoclass:: nni.retiarii.hub.pytorch.nasnet.NDSStage
+   :members:
+
+.. autoclass:: nni.retiarii.hub.pytorch.nasnet.NDSStagePathSampling
+   :members:
+
+.. autoclass:: nni.retiarii.hub.pytorch.nasnet.NDSStageDifferentiable
+   :members:
+
 ENAS
 ^^^^
......
@@ -298,6 +298,13 @@ class Classification(Lightning):
     """
     Evaluator that is used for classification.

+    Available callback metrics in :class:`Classification` are:
+
+    - train_loss
+    - train_acc
+    - val_loss
+    - val_acc
+
     Parameters
     ----------
     criterion : nn.Module

@@ -367,6 +374,13 @@ class Regression(Lightning):
     """
     Evaluator that is used for regression.

+    Available callback metrics in :class:`Regression` are:
+
+    - train_loss
+    - train_mse
+    - val_loss
+    - val_mse
+
     Parameters
     ----------
     criterion : nn.Module
......
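A sketch of how these callback metric names might be consumed, for example monitoring ``val_acc`` with a Lightning early-stopping callback. Forwarding ``Trainer`` keyword arguments such as ``callbacks`` through the evaluator is an assumption here, as are the hyper-parameters:

.. code-block:: python

    from pytorch_lightning.callbacks import EarlyStopping
    from nni.retiarii.evaluator.pytorch import Classification

    # `val_acc` is one of the callback metrics listed above for Classification.
    # Trainer keyword arguments (e.g. `callbacks`) are assumed to be forwarded by
    # the evaluator; add train/val dataloaders as in the previous sketch.
    evaluator = Classification(
        max_epochs=20,
        callbacks=[EarlyStopping(monitor='val_acc', mode='max', patience=3)],
    )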
@@ -5,7 +5,6 @@ from typing import Optional, Tuple, cast, Any, Dict, Union
 import torch
 import torch.nn.functional as F
-from timm.models.layers import trunc_normal_, DropPath

 import nni.nas.nn.pytorch as nn
 from nni.nas import model_wrapper, basic_unit

@@ -17,6 +16,12 @@ from nni.nas.oneshot.pytorch.supermodule._operation_utils import Slicable as _S,
 from .utils.fixed import FixedFactory
 from .utils.pretrained import load_pretrained_weight

+try:
+    TIMM_INSTALLED = True
+    from timm.models.layers import trunc_normal_, DropPath
+except ImportError:
+    TIMM_INSTALLED = False
+

 class RelativePosition2D(nn.Module):
     def __init__(self, head_embed_dim, length=14,) -> None:

@@ -355,6 +360,10 @@ class AutoformerSpace(nn.Module):
         rpe: bool = True,
     ):
         super().__init__()
+
+        if not TIMM_INSTALLED:
+            raise ImportError('timm must be installed to use AutoFormer.')
+
         # define search space parameters
         embed_dim = nn.ValueChoice(list(search_embed_dim), label="embed_dim")
         depth = nn.ValueChoice(list(search_depth), label="depth")
......
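The change above defers the ``timm`` dependency until ``AutoformerSpace`` is actually constructed. A generic sketch of the same optional-dependency pattern; the module (``scipy``) and class (``SmoothedCurve``) are placeholders, not part of NNI:

.. code-block:: python

    # Guarded import: record availability instead of failing at import time.
    try:
        SCIPY_INSTALLED = True
        import scipy.signal  # any optional third-party dependency
    except ImportError:
        SCIPY_INSTALLED = False


    class SmoothedCurve:
        """Only this class needs the optional dependency, so check it lazily."""

        def __init__(self, data):
            if not SCIPY_INSTALLED:
                raise ImportError('scipy must be installed to use SmoothedCurve.')
            self.smoothed = scipy.signal.savgol_filter(data, window_length=5, polyorder=2)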
@@ -66,6 +66,25 @@ class NasBench101(nn.Module):
     """The full search space proposed by `NAS-Bench-101 <http://proceedings.mlr.press/v97/ying19a/ying19a.pdf>`__.

     It's simply a stack of :class:`~nni.retiarii.nn.pytorch.NasBench101Cell`. Operations are conv3x3, conv1x1 and maxpool.

+    Parameters
+    ----------
+    stem_out_channels
+        Number of output channels of the stem convolution.
+    num_stacks
+        Number of stacks in the network.
+    num_modules_per_stack
+        Number of modules in each stack. Each module is a :class:`~nni.retiarii.nn.pytorch.NasBench101Cell`.
+    max_num_vertices
+        Maximum number of vertices in each cell.
+    max_num_edges
+        Maximum number of edges in each cell.
+    num_labels
+        Number of categories for classification.
+    bn_eps
+        Epsilon for batch normalization.
+    bn_momentum
+        Momentum for batch normalization.
     """

     def __init__(self,
......
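A minimal sketch of constructing the space with the parameters documented above; the argument values are illustrative, not defaults taken from this diff:

.. code-block:: python

    from nni.retiarii.hub.pytorch import NasBench101

    # Keyword arguments follow the Parameters section above; values are illustrative.
    space = NasBench101(
        stem_out_channels=128,
        num_stacks=3,
        num_modules_per_stack=3,
        num_labels=10,
    )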
@@ -154,6 +154,15 @@ class NasBench201(nn.Module):
     """The full search space proposed by `NAS-Bench-201 <https://arxiv.org/abs/2001.00326>`__.

     It's a stack of :class:`~nni.retiarii.nn.pytorch.NasBench201Cell`.

+    Parameters
+    ----------
+    stem_out_channels
+        The output channels of the stem.
+    num_modules_per_stack
+        The number of modules (cells) in each stack. Each cell is a :class:`~nni.retiarii.nn.pytorch.NasBench201Cell`.
+    num_labels
+        Number of categories for classification.
     """

     def __init__(self,
                  stem_out_channels: int = 16,
......
@@ -439,6 +439,19 @@ class NDSStageDifferentiable(DifferentiableMixedRepeat):
 _INIT_PARAMETER_DOCS = """

+    Notes
+    -----
+
+    To use NDS spaces with one-shot strategies,
+    especially when depth is mutating (i.e., ``num_cells`` is set to a tuple / list),
+    please pass :class:`~nni.retiarii.hub.pytorch.nasnet.NDSStagePathSampling` (for ENAS and RandomOneShot)
+    or :class:`~nni.retiarii.hub.pytorch.nasnet.NDSStageDifferentiable` (for DARTS and Proxyless) via ``mutation_hooks``.
+    This is because the output shape of each stacked block in :class:`~nni.retiarii.hub.pytorch.nasnet.NDSStage` can be different.
+
+    For example::
+
+        from nni.retiarii.hub.pytorch.nasnet import NDSStageDifferentiable
+        darts_strategy = strategy.DARTS(mutation_hooks=[NDSStageDifferentiable.mutate])
+
     Parameters
     ----------
     width

@@ -451,6 +464,8 @@ _INIT_PARAMETER_DOCS = """
     auxiliary_loss
         If true, another auxiliary classification head will produce another prediction.
         This makes the network output two logits in the training phase.
+    drop_path_prob
+        Apply drop path. Enabled when it's set to be greater than 0.
     """

@@ -475,8 +490,6 @@ class NDS(nn.Module):
         See :class:`~nni.retiarii.nn.pytorch.Cell`.
     num_nodes_per_cell
         See :class:`~nni.retiarii.nn.pytorch.Cell`.
-    drop_path_prob : float
-        Apply drop path. Enabled when it's set to be greater than 0.
     """

     def __init__(self,
......
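Complementing the DARTS example in the notes above, a sketch of the path-sampling counterpart mentioned there; assuming ``strategy`` refers to ``nni.retiarii.strategy``:

.. code-block:: python

    from nni.retiarii import strategy
    from nni.retiarii.hub.pytorch.nasnet import NDSStagePathSampling

    # For sampling-based one-shot strategies (ENAS, RandomOneShot), the stage hook
    # is passed the same way as in the differentiable example above.
    enas_strategy = strategy.ENAS(mutation_hooks=[NDSStagePathSampling.mutate])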
@@ -157,6 +157,25 @@ class InvertedResidual(nn.Sequential):
     - https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L453
     - https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L134

+    Parameters
+    ----------
+    in_channels
+        The number of input channels. Can be a value choice.
+    out_channels
+        The number of output channels. Can be a value choice.
+    expand_ratio
+        The ratio of intermediate channels with respect to input channels. Can be a value choice.
+    kernel_size
+        The kernel size of the depthwise convolution. Can be a value choice.
+    stride
+        The stride of the depthwise convolution.
+    squeeze_excite
+        Callable to create the squeeze-and-excitation layer. Takes hidden channels and input channels as arguments.
+    norm_layer
+        Callable to create the normalization layer. Takes input channels as argument.
+    activation_layer
+        Callable to create the activation layer. Takes no arguments.
     """

     def __init__(

@@ -252,6 +271,21 @@ class ProxylessNAS(nn.Module):
     We note that :class:`MobileNetV3Space` is different in this perspective.
     This space can be implemented as part of :class:`MobileNetV3Space`, but we keep it separate, following conventions.

+    Parameters
+    ----------
+    num_labels
+        The number of labels for classification.
+    base_widths
+        Widths of each stage, from stem, to body, to head. Length should be 9.
+    dropout_rate
+        Dropout rate for the final classification layer.
+    width_mult
+        Width multiplier for the model.
+    bn_eps
+        Epsilon for batch normalization.
+    bn_momentum
+        Momentum for batch normalization.
     """

     def __init__(self, num_labels: int = 1000,
......
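A brief sketch of instantiating the space with the parameters documented above; the values are illustrative only:

.. code-block:: python

    from nni.retiarii.hub.pytorch import ProxylessNAS

    # Arguments mirror the Parameters section above; values are illustrative.
    space = ProxylessNAS(
        num_labels=1000,
        dropout_rate=0.2,  # dropout before the final classification layer
        width_mult=1.0,    # scales the base widths of every stage
    )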
@@ -33,6 +33,12 @@ class DartsLightningModule(BaseOneShotLightningModule):
     The current implementation corresponds to DARTS (1st order) in the paper.
     Second order (unrolled 2nd-order derivatives) is not supported yet.

+    .. note::
+
+       DARTS runs a weighted sum of possible architectures under the hood.
+       Please bear in mind that it will be slower and consume more memory than training a single architecture.
+       The common practice is to down-scale the network (e.g., smaller depth / width) for speedup.
+
     .. versionadded:: 2.8

        Supports searching for ValueChoices on operations, with the technique described in

@@ -215,6 +221,12 @@ class GumbelDartsLightningModule(DartsLightningModule):
     * :class:`nni.retiarii.nn.pytorch.Cell`.
     * :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.

+    .. note::
+
+       GumbelDARTS runs a weighted sum of possible architectures under the hood.
+       Please bear in mind that it will be slower and consume more memory than training a single architecture.
+       The common practice is to down-scale the network (e.g., smaller depth / width) for speedup.
+
     {{module_notes}}

     {optimization_note}
......
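Following the down-scaling advice in the notes above, a sketch of searching a reduced DARTS space; the width/depth values and the import alias are assumptions:

.. code-block:: python

    from nni.retiarii.hub.pytorch import DARTS as DartsSpace
    from nni.retiarii import strategy

    # Down-scaled supernet: fewer cells and a narrower width keep the weighted-sum
    # supermodel (all candidate ops held in memory) affordable during the search.
    model_space = DartsSpace(width=16, num_cells=8, dataset='cifar')
    darts_strategy = strategy.DARTS()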
@@ -5,6 +5,7 @@
 from __future__ import annotations

 import warnings
+import logging
 from typing import Any, cast, Dict

 import pytorch_lightning as pl

@@ -13,13 +14,17 @@ import torch.nn as nn
 import torch.optim as optim

 from .base_lightning import MANUAL_OPTIMIZATION_NOTE, BaseOneShotLightningModule, MutationHook, no_default_hook
+from .supermodule.base import sub_state_dict
 from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES
 from .supermodule.sampling import (
     PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy,
     PathSamplingCell, PathSamplingRepeat
 )
 from .enas import ReinforceController, ReinforceField
-from .supermodule.base import sub_state_dict
+
+_logger = logging.getLogger(__name__)


 class RandomSamplingLightningModule(BaseOneShotLightningModule):
     _random_note = """

@@ -141,7 +146,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
     - Firstly, training model parameters.
     - Secondly, training the ENAS RL agent. The agent will produce a sample of model architecture to get the best reward.

-    .. note::
+    .. attention::

        ENAS requires the evaluator to report metrics via ``self.log`` in its ``validation_step``.
        See the explanation of ``reward_metric_name`` for details.

@@ -214,6 +219,13 @@ class EnasLightningModule(RandomSamplingLightningModule):
                  mutation_hooks: list[MutationHook] | None = None):
         super().__init__(inner_module, mutation_hooks)

+        if reward_metric_name is None:
+            _logger.warning(
+                'It is strongly recommended to have `reward_metric_name` specified. '
+                'It should be one of the metrics logged in `self.log` in evaluator. '
+                'Otherwise it will infer the reward based on certain rules.'
+            )
+
         # convert parameter spec to legacy ReinforceField
         # this part will be refactored
         self.nas_fields: list[ReinforceField] = []
......
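Given the warning added above, a sketch of passing an explicit reward metric when using the ENAS strategy; ``val_acc`` matches the callback metrics of the ``Classification`` evaluator, and it is assumed here that the strategy forwards this argument to ``EnasLightningModule``:

.. code-block:: python

    from nni.retiarii import strategy

    # `val_acc` is logged by the Classification evaluator via `self.log` in
    # `validation_step`, so ENAS can use it directly as the RL reward.
    enas_strategy = strategy.ENAS(reward_metric_name='val_acc')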