Unverified Commit bc6d8796 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mobilenetv3 import MobileNetV3Space
from .nasbench101 import NasBench101
from .nasbench201 import NasBench201
from .nasnet import NDS, NASNet, ENAS, AmoebaNet, PNAS, DARTS
from .proxylessnas import ProxylessNAS
from .shufflenet import ShuffleNetSpace
from .autoformer import AutoformerSpace
\ No newline at end of file
......@@ -7,12 +7,12 @@ import torch
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper, basic_unit
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.oneshot.pytorch.supermodule.operation import MixedOperation
from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options
from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable as _S, MaybeWeighted as _W
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper, basic_unit
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.nas.oneshot.pytorch.supermodule.operation import MixedOperation
from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options
from nni.nas.oneshot.pytorch.supermodule._operation_utils import Slicable as _S, MaybeWeighted as _W
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
......
......@@ -5,8 +5,8 @@ from functools import partial
from typing import Tuple, Optional, Callable, Union, List, Type, cast
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from nni.typehint import Literal
from .proxylessnas import ConvBNReLU, InvertedResidual, DepthwiseSeparableConv, make_divisible, reset_parameters
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Famous building blocks of search spaces."""
from .autoactivation import *
from .nasbench101 import *
from .nasbench201 import *
......@@ -7,10 +7,10 @@ from packaging.version import Version
import torch
import torch.nn as nn
from nni.retiarii.serializer import basic_unit
from nni.nas.utils import basic_unit
from .api import LayerChoice
from .mutation_utils import generate_new_label
from nni.nas.nn.pytorch import LayerChoice
from nni.nas.nn.pytorch.mutation_utils import generate_new_label
__all__ = ['AutoActivation']
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['NasBench101Cell', 'NasBench101Mutator']
import logging
from collections import OrderedDict
from typing import Callable, List, Optional, Union, Dict, Tuple, cast
......@@ -9,10 +11,10 @@ import numpy as np
import torch
import torch.nn as nn
from nni.retiarii.mutator import InvalidMutation, Mutator
from nni.retiarii.graph import Model
from .api import InputChoice, ValueChoice, LayerChoice
from .mutation_utils import Mutable, generate_new_label, get_fixed_dict
from nni.nas.mutable import InvalidMutation, Mutator
from nni.nas.execution.common import Model
from nni.nas.nn.pytorch import InputChoice, ValueChoice, LayerChoice
from nni.nas.nn.pytorch.mutation_utils import Mutable, generate_new_label, get_fixed_dict
_logger = logging.getLogger(__name__)
......
......@@ -5,8 +5,9 @@ import math
import torch
import torch.nn as nn
from nni.retiarii import model_wrapper
from nni.retiarii.nn.pytorch import NasBench101Cell
from nni.nas import model_wrapper
from .modules.nasbench101 import NasBench101Cell
__all__ = ['NasBench101']
......
......@@ -6,8 +6,8 @@ from typing import Callable, Dict
import torch
import torch.nn as nn
from nni.retiarii import model_wrapper
from nni.retiarii.nn.pytorch import NasBench201Cell
from nni.nas import model_wrapper
from .modules.nasbench201 import NasBench201Cell
__all__ = ['NasBench201']
......
......@@ -18,11 +18,11 @@ except ImportError:
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from nni.retiarii.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedRepeat
from nni.nas.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat
from nni.nas.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedRepeat
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
......@@ -394,7 +394,7 @@ class NDSStagePathSampling(PathSamplingRepeat):
"""The path-sampling implementation (for one-shot) of each NDS stage if depth is mutating."""
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.api.ValueChoiceX):
if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.choice.ValueChoiceX):
return cls(
module.first_cell_transformation_factory(),
cast(List[nn.Module], module.blocks),
......@@ -419,7 +419,7 @@ class NDSStageDifferentiable(DifferentiableMixedRepeat):
"""The differentiable implementation (for one-shot) of each NDS stage if depth is mutating."""
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.api.ValueChoiceX):
if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.choice.ValueChoiceX):
# Only interesting when depth is mutable
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(
......
......@@ -5,8 +5,8 @@ import math
from typing import Optional, Callable, List, Tuple, Iterator, Union, cast, overload
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
......
......@@ -4,8 +4,8 @@
from typing import cast
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""This file should be merged to nni/retiarii/fixed.py"""
"""This file should be merged to nni/nas/fixed.py"""
from typing import Type
from nni.retiarii.utils import ContextStack
from nni.nas.utils import ContextStack
class FixedFactory:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import *
......@@ -4,7 +4,7 @@
import warnings
from typing import (Any, Iterable, List, Optional, Tuple, cast)
from .graph import Model, Mutation, ModelStatus
from nni.nas.execution import Model, Mutation, ModelStatus
__all__ = ['Sampler', 'Mutator', 'InvalidMutation']
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .choice import *
from .repeat import *
from .cell import *
from .layers import *
......@@ -12,8 +12,8 @@ except ImportError:
import torch
import torch.nn as nn
from .api import ChosenInputs, LayerChoice, InputChoice
from .nn import ModuleList # pylint: disable=no-name-in-module
from .choice import ChosenInputs, LayerChoice, InputChoice
from .layers import ModuleList # pylint: disable=no-name-in-module
from .mutation_utils import generate_new_label
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment