Unverified Commit bc6d8796 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mobilenetv3 import MobileNetV3Space
from .nasbench101 import NasBench101
from .nasbench201 import NasBench201
from .nasnet import NDS, NASNet, ENAS, AmoebaNet, PNAS, DARTS
from .proxylessnas import ProxylessNAS
from .shufflenet import ShuffleNetSpace
from .autoformer import AutoformerSpace
\ No newline at end of file
...@@ -7,12 +7,12 @@ import torch ...@@ -7,12 +7,12 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath from timm.models.layers import trunc_normal_, DropPath
import nni.retiarii.nn.pytorch as nn import nni.nas.nn.pytorch as nn
from nni.retiarii import model_wrapper, basic_unit from nni.nas import model_wrapper, basic_unit
from nni.retiarii.nn.pytorch.api import ValueChoiceX from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.retiarii.oneshot.pytorch.supermodule.operation import MixedOperation from nni.nas.oneshot.pytorch.supermodule.operation import MixedOperation
from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options
from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable as _S, MaybeWeighted as _W from nni.nas.oneshot.pytorch.supermodule._operation_utils import Slicable as _S, MaybeWeighted as _W
from .utils.fixed import FixedFactory from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight from .utils.pretrained import load_pretrained_weight
......
...@@ -5,8 +5,8 @@ from functools import partial ...@@ -5,8 +5,8 @@ from functools import partial
from typing import Tuple, Optional, Callable, Union, List, Type, cast from typing import Tuple, Optional, Callable, Union, List, Type, cast
import torch import torch
import nni.retiarii.nn.pytorch as nn import nni.nas.nn.pytorch as nn
from nni.retiarii import model_wrapper from nni.nas import model_wrapper
from nni.typehint import Literal from nni.typehint import Literal
from .proxylessnas import ConvBNReLU, InvertedResidual, DepthwiseSeparableConv, make_divisible, reset_parameters from .proxylessnas import ConvBNReLU, InvertedResidual, DepthwiseSeparableConv, make_divisible, reset_parameters
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Famous building blocks of search spaces."""
from .autoactivation import *
from .nasbench101 import *
from .nasbench201 import *
...@@ -7,10 +7,10 @@ from packaging.version import Version ...@@ -7,10 +7,10 @@ from packaging.version import Version
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.retiarii.serializer import basic_unit from nni.nas.utils import basic_unit
from .api import LayerChoice from nni.nas.nn.pytorch import LayerChoice
from .mutation_utils import generate_new_label from nni.nas.nn.pytorch.mutation_utils import generate_new_label
__all__ = ['AutoActivation'] __all__ = ['AutoActivation']
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
__all__ = ['NasBench101Cell', 'NasBench101Mutator']
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from typing import Callable, List, Optional, Union, Dict, Tuple, cast from typing import Callable, List, Optional, Union, Dict, Tuple, cast
...@@ -9,10 +11,10 @@ import numpy as np ...@@ -9,10 +11,10 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.retiarii.mutator import InvalidMutation, Mutator from nni.nas.mutable import InvalidMutation, Mutator
from nni.retiarii.graph import Model from nni.nas.execution.common import Model
from .api import InputChoice, ValueChoice, LayerChoice from nni.nas.nn.pytorch import InputChoice, ValueChoice, LayerChoice
from .mutation_utils import Mutable, generate_new_label, get_fixed_dict from nni.nas.nn.pytorch.mutation_utils import Mutable, generate_new_label, get_fixed_dict
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
......
...@@ -5,8 +5,9 @@ import math ...@@ -5,8 +5,9 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.retiarii import model_wrapper
from nni.retiarii.nn.pytorch import NasBench101Cell from nni.nas import model_wrapper
from .modules.nasbench101 import NasBench101Cell
__all__ = ['NasBench101'] __all__ = ['NasBench101']
......
...@@ -6,8 +6,8 @@ from typing import Callable, Dict ...@@ -6,8 +6,8 @@ from typing import Callable, Dict
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.retiarii import model_wrapper from nni.nas import model_wrapper
from nni.retiarii.nn.pytorch import NasBench201Cell from .modules.nasbench201 import NasBench201Cell
__all__ = ['NasBench201'] __all__ = ['NasBench201']
......
...@@ -18,11 +18,11 @@ except ImportError: ...@@ -18,11 +18,11 @@ except ImportError:
import torch import torch
import nni.retiarii.nn.pytorch as nn import nni.nas.nn.pytorch as nn
from nni.retiarii import model_wrapper from nni.nas import model_wrapper
from nni.retiarii.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat from nni.nas.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedRepeat from nni.nas.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedRepeat
from .utils.fixed import FixedFactory from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight from .utils.pretrained import load_pretrained_weight
...@@ -394,7 +394,7 @@ class NDSStagePathSampling(PathSamplingRepeat): ...@@ -394,7 +394,7 @@ class NDSStagePathSampling(PathSamplingRepeat):
"""The path-sampling implementation (for one-shot) of each NDS stage if depth is mutating.""" """The path-sampling implementation (for one-shot) of each NDS stage if depth is mutating."""
@classmethod @classmethod
def mutate(cls, module, name, memo, mutate_kwargs): def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.api.ValueChoiceX): if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.choice.ValueChoiceX):
return cls( return cls(
module.first_cell_transformation_factory(), module.first_cell_transformation_factory(),
cast(List[nn.Module], module.blocks), cast(List[nn.Module], module.blocks),
...@@ -419,7 +419,7 @@ class NDSStageDifferentiable(DifferentiableMixedRepeat): ...@@ -419,7 +419,7 @@ class NDSStageDifferentiable(DifferentiableMixedRepeat):
"""The differentiable implementation (for one-shot) of each NDS stage if depth is mutating.""" """The differentiable implementation (for one-shot) of each NDS stage if depth is mutating."""
@classmethod @classmethod
def mutate(cls, module, name, memo, mutate_kwargs): def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.api.ValueChoiceX): if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.choice.ValueChoiceX):
# Only interesting when depth is mutable # Only interesting when depth is mutable
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1)) softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls( return cls(
......
...@@ -5,8 +5,8 @@ import math ...@@ -5,8 +5,8 @@ import math
from typing import Optional, Callable, List, Tuple, Iterator, Union, cast, overload from typing import Optional, Callable, List, Tuple, Iterator, Union, cast, overload
import torch import torch
import nni.retiarii.nn.pytorch as nn import nni.nas.nn.pytorch as nn
from nni.retiarii import model_wrapper from nni.nas import model_wrapper
from .utils.fixed import FixedFactory from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight from .utils.pretrained import load_pretrained_weight
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
from typing import cast from typing import cast
import torch import torch
import nni.retiarii.nn.pytorch as nn import nni.nas.nn.pytorch as nn
from nni.retiarii import model_wrapper from nni.nas import model_wrapper
from .utils.fixed import FixedFactory from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight from .utils.pretrained import load_pretrained_weight
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
"""This file should be merged to nni/retiarii/fixed.py""" """This file should be merged to nni/nas/fixed.py"""
from typing import Type from typing import Type
from nni.retiarii.utils import ContextStack from nni.nas.utils import ContextStack
class FixedFactory: class FixedFactory:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import *
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import warnings import warnings
from typing import (Any, Iterable, List, Optional, Tuple, cast) from typing import (Any, Iterable, List, Optional, Tuple, cast)
from .graph import Model, Mutation, ModelStatus from nni.nas.execution import Model, Mutation, ModelStatus
__all__ = ['Sampler', 'Mutator', 'InvalidMutation'] __all__ = ['Sampler', 'Mutator', 'InvalidMutation']
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .choice import *
from .repeat import *
from .cell import *
from .layers import *
...@@ -12,8 +12,8 @@ except ImportError: ...@@ -12,8 +12,8 @@ except ImportError:
import torch import torch
import torch.nn as nn import torch.nn as nn
from .api import ChosenInputs, LayerChoice, InputChoice from .choice import ChosenInputs, LayerChoice, InputChoice
from .nn import ModuleList # pylint: disable=no-name-in-module from .layers import ModuleList # pylint: disable=no-name-in-module
from .mutation_utils import generate_new_label from .mutation_utils import generate_new_label
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment