Unverified commit bc6d8796 authored by Yuge Zhang, committed by GitHub

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
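
Every hunk below performs the same mechanical change: imports from the old nni.retiarii namespace are redirected to the promoted nni.nas namespace, with a few submodules relocated along the way (serializer/utils helpers into nni.nas.utils, nn.pytorch.api into nn.pytorch.choice, graph into execution.common). A minimal before/after sketch, using only paths that appear in the hunks below:

    # before: Retiarii namespace
    from nni.retiarii.serializer import basic_unit
    from nni.retiarii.nn.pytorch.api import ValueChoiceX
    from nni.retiarii.graph import Model

    # after: promoted NAS namespace
    from nni.nas.utils import basic_unit
    from nni.nas.nn.pytorch.choice import ValueChoiceX
    from nni.nas.execution.common import Model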
......@@ -13,9 +13,7 @@ import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import (STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace,
NoContextError)
from nni.nas.utils import STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace, NoContextError, basic_unit
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
......
......@@ -3,11 +3,11 @@
from pathlib import Path
# To make auto-completion happy, we generate a _nn.py that lists out all the classes.
nn_cache_file_path = Path(__file__).parent / '_nn.py'
# To make auto-completion happy, we generate a _layers.py that lists out all the classes.
nn_cache_file_path = Path(__file__).parent / '_layers.py'
# Update this when cache format changes, to enforce an update.
cache_version = 2
cache_version = 3
def validate_cache() -> bool:
......@@ -70,7 +70,7 @@ def generate_stub_file() -> str:
f'# _torch_nn_cache_version = {cache_version}',
'import typing',
'import torch.nn as nn',
'from nni.retiarii.serializer import basic_unit',
'from nni.nas.utils import basic_unit',
]
all_names = []
......@@ -113,4 +113,4 @@ if not validate_cache():
del Path, validate_cache, write_cache, cache_version, nn_cache_file_path, code
from ._nn import * # pylint: disable=import-error, wildcard-import, unused-wildcard-import
from ._layers import * # pylint: disable=import-error, wildcard-import, unused-wildcard-import
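
For context, generate_stub_file assembles the header shown above and then lists out one entry per torch.nn class; the header lines are taken from this hunk, while the per-class entry is a hypothetical illustration of the cached stub's shape, not code from this diff:

    # auto-generated stub (illustrative header, per generate_stub_file above)
    # _torch_nn_cache_version = 3
    import typing
    import torch.nn as nn
    from nni.nas.utils import basic_unit

    # hypothetical entry: each torch.nn module class re-exported as a serializable basic unit
    Linear = basic_unit(nn.Linear)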
......@@ -4,7 +4,7 @@
from typing import Any, Optional, Tuple, Union
import torch.nn as nn
from nni.retiarii.utils import NoContextError, ModelNamespace, get_current_context
from nni.nas.utils import NoContextError, ModelNamespace, get_current_context
class Mutable(nn.Module):
......
......@@ -7,13 +7,13 @@ from typing import Any, List, Optional, Tuple, Dict, Iterator, Iterable, cast
import torch.nn as nn
from nni.common.serializer import is_traceable, is_wrapped_with_trace
from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node, Evaluator
from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
from nni.retiarii.utils import ModelNamespace, uid
from nni.nas.execution.common.graph import Graph, Model, ModelStatus, Node, Evaluator
from nni.nas.execution.common.graph_op import Cell
from nni.nas.hub.pytorch.modules import NasBench101Cell, NasBench101Mutator
from nni.nas.mutable import Mutator
from nni.nas.utils import is_basic_unit, is_model_wrapped, ModelNamespace, uid
from .api import LayerChoice, InputChoice, ValueChoice, ValueChoiceX, Placeholder
from .component import NasBench101Cell, NasBench101Mutator
from .choice import LayerChoice, InputChoice, ValueChoice, ValueChoiceX, Placeholder
class LayerChoiceMutator(Mutator):
......@@ -60,7 +60,7 @@ class InputChoiceMutator(Mutator):
chosen = [self.choice(candidates) for _ in range(n_chosen)]
for node in self.nodes:
target = cast(Node, model.get_node_by_name(node.name))
target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs',
target.update_operation('__torch__.nni.nas.nn.pytorch.ChosenInputs',
{'chosen': chosen, 'reduction': node.operation.parameters['reduction']})
......@@ -171,7 +171,7 @@ class RepeatMutator(Mutator):
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
applied_mutators = []
ic_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.api.InputChoice'))
ic_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.nas.nn.pytorch.choice.InputChoice'))
for node_list in ic_nodes:
assert _is_all_equal(map(lambda node: node.operation.parameters['n_candidates'], node_list)) and \
_is_all_equal(map(lambda node: node.operation.parameters['n_chosen'], node_list)), \
......@@ -179,7 +179,7 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
mutator = InputChoiceMutator(node_list)
applied_mutators.append(mutator)
vc_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.api.ValueChoice'))
vc_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.nas.nn.pytorch.choice.ValueChoice'))
for node_list in vc_nodes:
assert _is_all_equal(map(lambda node: node.operation.parameters['candidates'], node_list)), \
'Value choice with the same label must have the same candidates.'
......
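
process_inline_mutation in the hunks above leans on two helpers defined elsewhere in this file; a plausible sketch of their behavior, written here only to make the asserts readable (the actual implementations are not part of this diff):

    from typing import Iterable, List

    def _is_all_equal(values: Iterable) -> bool:
        # true when every element compares equal (used to check that choices
        # sharing a label agree on candidates / n_chosen)
        values = list(values)
        return all(v == values[0] for v in values)

    def _group_by_label(nodes) -> List[list]:
        # bucket nodes by their 'label' parameter, so that choices with the
        # same label are mutated together by a single mutator
        groups: dict = {}
        for node in nodes:
            groups.setdefault(node.operation.parameters['label'], []).append(node)
        return list(groups.values())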
......@@ -3,9 +3,9 @@
from __future__ import annotations
import tensorflow as tf
from tensorflow.keras.layers import Layer
class LayerChoice(tf.keras.Layer):
class LayerChoice(Layer):
# FIXME: This is only a draft to test multi-framework support; it's not implemented at all.
pass
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasLightningModule, RandomSamplingLightningModule
......@@ -14,10 +14,10 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import nni.retiarii.nn.pytorch as nas_nn
import nni.nas.nn.pytorch as nas_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.typehint import Literal
from .supermodule.base import BaseSuperNetModule
......@@ -122,7 +122,7 @@ def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_k
nas_nn.LayerChoice,
nas_nn.InputChoice,
nas_nn.Repeat,
nas_nn.NasBench101Cell,
# nas_nn.NasBench101Cell, # FIXME: nasbench101 is moved to hub, can't check any more.
# nas_nn.ValueChoice, # could be false positive
# nas_nn.Cell, # later
# nas_nn.NasBench201Cell, # forward = supernet
......@@ -156,8 +156,8 @@ class BaseOneShotLightningModule(pl.LightningModule):
Extra mutation hooks to support customized mutation on primitives other than built-ins.
Mutation hooks are callables that take a module and return a
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
They are invoked in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodules.
:class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
They are invoked in :func:`~nni.nas.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodule.
For each submodule, the hooks in the list are invoked sequentially,
and later hooks can see the results from previous hooks.
The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,
......@@ -177,21 +177,21 @@ class BaseOneShotLightningModule(pl.LightningModule):
The returned value can also be one of the three kinds:
1. tuple of: :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and boolean,
1. tuple of: :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and boolean,
2. boolean,
3. :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.
3. :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.
The boolean value is ``suppress``, which indicates whether the following hooks should be called.
When it's true, the subsequent hooks are suppressed and will never be invoked.
If no boolean value is specified, it's assumed to be false.
If a ``None`` value appears in the place of
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
:class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
it means the hook suggests keeping the module unchanged, and nothing will happen.
An example of mutation hook is given in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.no_default_hook`.
An example of mutation hook is given in :func:`~nni.nas.oneshot.pytorch.base_lightning.no_default_hook`.
However, it's recommended to implement mutation hooks by deriving from
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
:class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
and adding its classmethod ``mutate`` to this list.
"""
......@@ -309,7 +309,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
Combine architecture optimizers and user's model optimizers.
You can override :meth:`configure_architecture_optimizers` if architecture optimizers are needed in your NAS algorithm.
For now :attr:`model` is tested against evaluators in :mod:`nni.retiarii.evaluator.pytorch.lightning`
For now :attr:`model` is tested against evaluators in :mod:`nni.nas.evaluator.pytorch.lightning`
and it only returns one optimizer.
But for extensibility, code for other return value types is also implemented.
"""
......
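
The docstring above describes combining architecture optimizers with the user's model optimizers; a hedged sketch of what overriding configure_architecture_optimizers might look like, where the 'alpha' naming of architecture parameters is a hypothetical convention, not taken from this diff:

    import torch
    from nni.nas.oneshot.pytorch.base_lightning import BaseOneShotLightningModule

    class MyOneShotModule(BaseOneShotLightningModule):      # hypothetical subclass
        def configure_architecture_optimizers(self):
            # hypothetical: gather architecture parameters by name and give them
            # their own optimizer, separate from the model-weight optimizers
            arch_params = [p for n, p in self.named_parameters() if 'alpha' in n]
            return torch.optim.Adam(arch_params, lr=3e-4)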
......@@ -4,8 +4,8 @@
"""Strategy integration of one-shot.
This file is put here simply because it relies on "pytorch".
For consistency, please consider importing strategies from ``nni.retiarii.strategy``.
For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires pytorch to be installed of course).
For consistency, please consider importing strategies from ``nni.nas.strategy``.
For example, ``nni.nas.strategy.DartsStrategy`` (this requires pytorch to be installed of course).
When adding/modifying a new strategy in this file, don't forget to link it in strategy/oneshot.py.
"""
......@@ -17,9 +17,9 @@ from typing import Any, Type
import torch.nn as nn
from nni.retiarii.graph import Model
from nni.retiarii.strategy.base import BaseStrategy
from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule
from nni.nas.execution.common import Model
from nni.nas.strategy.base import BaseStrategy
from nni.nas.evaluator.pytorch.lightning import Lightning, LightningModule
from .base_lightning import BaseOneShotLightningModule
from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
......
......@@ -14,7 +14,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.retiarii.nn.pytorch import ValueChoice
from nni.nas.nn.pytorch import ValueChoice
class DifferentiableSuperConv2d(nn.Conv2d):
......
......@@ -13,7 +13,7 @@ import numpy as np
import torch
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch.api import ChoiceOf, ValueChoiceX
from nni.nas.nn.pytorch.choice import ChoiceOf, ValueChoiceX
Choice = Any
......
......@@ -14,9 +14,9 @@ import torch.nn as nn
import torch.nn.functional as F
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs
from nni.nas.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.nas.nn.pytorch.cell import preprocess_cell_inputs
from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy
......@@ -28,7 +28,7 @@ _logger = logging.getLogger(__name__)
__all__ = [
'DifferentiableMixedLayer', 'DifferentiableMixedInput',
'DifferentiableMixedRepeat', 'DifferentiableMixedCell',
'MixedOpDifferentiablePolicy',
'MixedOpDifferentiablePolicy', 'GumbelSoftmax',
]
......
......@@ -18,10 +18,10 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import nni.retiarii.nn.pytorch as retiarii_nn
import nni.nas.nn.pytorch as nas_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.nas.nn.pytorch.choice import ValueChoiceX
from .base import BaseSuperNetModule
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant
......@@ -236,7 +236,7 @@ class MixedLinear(MixedOperation, nn.Linear):
Prefix of weight and bias will be sliced.
"""
bound_type = retiarii_nn.Linear
bound_type = nas_nn.Linear
argument_list = ['in_features', 'out_features']
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
......@@ -294,7 +294,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
□ □ □ □ □ □ □ □ □ □
"""
bound_type = retiarii_nn.Conv2d
bound_type = nas_nn.Conv2d
argument_list = [
'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'dilation', 'groups'
]
......@@ -427,7 +427,7 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
PyTorch BatchNorm supports a case where momentum can be ``None``, which is not supported here.
"""
bound_type = retiarii_nn.BatchNorm2d
bound_type = nas_nn.BatchNorm2d
argument_list = ['num_features', 'eps', 'momentum']
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
......@@ -488,7 +488,7 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
eps is required to be float.
"""
bound_type = retiarii_nn.LayerNorm
bound_type = nas_nn.LayerNorm
argument_list = ['normalized_shape', 'eps']
@staticmethod
......@@ -565,7 +565,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
All candidates of ``embed_dim`` should be divisible by all candidates of ``num_heads``.
"""
bound_type = retiarii_nn.MultiheadAttention
bound_type = nas_nn.MultiheadAttention
argument_list = ['embed_dim', 'num_heads', 'kdim', 'vdim', 'dropout']
def __post_init__(self):
......
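
The hunks above all follow one pattern: each MixedOperation subclass points bound_type at the searchable operation in nni.nas.nn.pytorch it mixes, and argument_list names the constructor arguments that may carry value choices. A small usage sketch of the divisibility constraint noted for MixedMultiHeadAttention, with illustrative candidate values (not taken from this diff):

    import nni.nas.nn.pytorch as nas_nn

    # every candidate embed_dim must be divisible by every candidate num_heads
    attention = nas_nn.MultiheadAttention(
        embed_dim=nas_nn.ValueChoice([64, 128, 256], label='embed_dim'),
        num_heads=nas_nn.ValueChoice([4, 8], label='num_heads'),
    )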
......@@ -11,9 +11,9 @@ import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.nas.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
from .base import BaseSuperNetModule
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base import BaseStrategy
from .bruteforce import Random, GridSearch
from .evolution import RegularizedEvolution
from .hpo import TPEStrategy, TPE
from .rl import PolicyBasedRL
from .oneshot import DARTS, Proxyless, GumbelDARTS, ENAS, RandomOneShot
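
With this __init__, multi-trial strategies (Random, GridSearch, RegularizedEvolution, TPE, PolicyBasedRL) and one-shot strategies (DARTS, Proxyless, GumbelDARTS, ENAS, RandomOneShot) share the nni.nas.strategy namespace. A minimal usage sketch; how the strategy is handed to an experiment is outside this diff:

    from nni.nas.strategy import Random, DARTS

    strategy = Random()      # multi-trial: samples architectures independently
    # strategy = DARTS()     # one-shot alternative; requires PyTorch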
......@@ -21,9 +21,9 @@ from tianshou.env.worker import EnvWorker
from typing_extensions import TypedDict
from nni.nas.execution import submit_models, wait_models
from nni.nas.execution.common import ModelStatus
from .utils import get_targeted_model
from ..graph import ModelStatus
from ..execution import submit_models, wait_models
_logger = logging.getLogger(__name__)
......
......@@ -4,8 +4,8 @@
import abc
from typing import List, Any
from ..graph import Model
from ..mutator import Mutator
from nni.nas.execution.common import Model
from nni.nas.mutable import Mutator
class BaseStrategy(abc.ABC):
......
......@@ -8,7 +8,8 @@ import random
import time
from typing import Any, Dict, List, Sequence, Optional
from .. import InvalidMutation, Sampler, submit_models, query_available_resources, budget_exhausted
from nni.nas.execution import submit_models, query_available_resources, budget_exhausted
from nni.nas.mutable import InvalidMutation, Sampler
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model
......
......@@ -6,9 +6,10 @@ import os
import random
import string
from .. import Sampler, codegen, utils
from ..execution.base import BaseGraphData
from ..execution.utils import get_mutation_summary
from nni.nas import Sampler, utils
from nni.nas.execution.pytorch import codegen
from nni.nas.execution.pytorch.graph import BaseGraphData
from nni.nas.execution.common import get_mutation_summary
from .base import BaseStrategy
_logger = logging.getLogger(__name__)
......