Unverified commit bc6d8796, authored by Yuge Zhang, committed by GitHub

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
@@ -13,9 +13,7 @@ import torch
 import torch.nn as nn
 from nni.common.hpo_utils import ParameterSpec
 from nni.common.serializer import Translatable
-from nni.retiarii.serializer import basic_unit
-from nni.retiarii.utils import (STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace,
-                                NoContextError)
+from nni.nas.utils import STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace, NoContextError, basic_unit

 from .mutation_utils import Mutable, generate_new_label, get_fixed_value
...
@@ -3,11 +3,11 @@
 from pathlib import Path

-# To make auto-completion happy, we generate a _nn.py that lists out all the classes.
-nn_cache_file_path = Path(__file__).parent / '_nn.py'
+# To make auto-completion happy, we generate a _layers.py that lists out all the classes.
+nn_cache_file_path = Path(__file__).parent / '_layers.py'

 # Update this when cache format changes, to enforce an update.
-cache_version = 2
+cache_version = 3

 def validate_cache() -> bool:
@@ -70,7 +70,7 @@ def generate_stub_file() -> str:
         f'# _torch_nn_cache_version = {cache_version}',
         'import typing',
         'import torch.nn as nn',
-        'from nni.retiarii.serializer import basic_unit',
+        'from nni.nas.utils import basic_unit',
     ]
     all_names = []
@@ -113,4 +113,4 @@ if not validate_cache():
 del Path, validate_cache, write_cache, cache_version, nn_cache_file_path, code

-from ._nn import *  # pylint: disable=import-error, wildcard-import, unused-wildcard-import
+from ._layers import *  # pylint: disable=import-error, wildcard-import, unused-wildcard-import
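For context, the generated ``_layers.py`` stub is roughly shaped like the excerpt below. This is a hypothetical reconstruction from the header lines in the hunk above; the exact form of the generated wrappers may differ, and the class subset is illustrative only.

# _torch_nn_cache_version = 3
import typing
import torch.nn as nn
from nni.nas.utils import basic_unit

# One wrapped alias per torch.nn class (illustrative subset only; the real
# generator enumerates all classes found in torch.nn):
Linear = basic_unit(nn.Linear)
Conv2d = basic_unit(nn.Conv2d)
BatchNorm2d = basic_unit(nn.BatchNorm2d)
...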
@@ -4,7 +4,7 @@
 from typing import Any, Optional, Tuple, Union

 import torch.nn as nn

-from nni.retiarii.utils import NoContextError, ModelNamespace, get_current_context
+from nni.nas.utils import NoContextError, ModelNamespace, get_current_context

 class Mutable(nn.Module):
...
@@ -7,13 +7,13 @@ from typing import Any, List, Optional, Tuple, Dict, Iterator, Iterable, cast
 import torch.nn as nn

 from nni.common.serializer import is_traceable, is_wrapped_with_trace
-from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node, Evaluator
-from nni.retiarii.mutator import Mutator
-from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
-from nni.retiarii.utils import ModelNamespace, uid
+from nni.nas.execution.common.graph import Graph, Model, ModelStatus, Node, Evaluator
+from nni.nas.execution.common.graph_op import Cell
+from nni.nas.hub.pytorch.modules import NasBench101Cell, NasBench101Mutator
+from nni.nas.mutable import Mutator
+from nni.nas.utils import is_basic_unit, is_model_wrapped, ModelNamespace, uid

-from .api import LayerChoice, InputChoice, ValueChoice, ValueChoiceX, Placeholder
-from .component import NasBench101Cell, NasBench101Mutator
+from .choice import LayerChoice, InputChoice, ValueChoice, ValueChoiceX, Placeholder

 class LayerChoiceMutator(Mutator):
@@ -60,7 +60,7 @@ class InputChoiceMutator(Mutator):
         chosen = [self.choice(candidates) for _ in range(n_chosen)]
         for node in self.nodes:
             target = cast(Node, model.get_node_by_name(node.name))
-            target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs',
+            target.update_operation('__torch__.nni.nas.nn.pytorch.ChosenInputs',
                                     {'chosen': chosen, 'reduction': node.operation.parameters['reduction']})
@@ -171,7 +171,7 @@ class RepeatMutator(Mutator):
 def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
     applied_mutators = []

-    ic_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.api.InputChoice'))
+    ic_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.nas.nn.pytorch.choice.InputChoice'))
     for node_list in ic_nodes:
         assert _is_all_equal(map(lambda node: node.operation.parameters['n_candidates'], node_list)) and \
             _is_all_equal(map(lambda node: node.operation.parameters['n_chosen'], node_list)), \
@@ -179,7 +179,7 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
         mutator = InputChoiceMutator(node_list)
         applied_mutators.append(mutator)

-    vc_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.api.ValueChoice'))
+    vc_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.nas.nn.pytorch.choice.ValueChoice'))
     for node_list in vc_nodes:
         assert _is_all_equal(map(lambda node: node.operation.parameters['candidates'], node_list)), \
             'Value choice with the same label must have the same candidates.'
...
@@ -3,9 +3,9 @@
 from __future__ import annotations

-import tensorflow as tf
+from tensorflow.keras import Layer

-class LayerChoice(tf.keras.Layer):
+class LayerChoice(Layer):
     # FIXME: This is only a draft to test multi-framework support; it's not implemented at all.
     pass
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
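A note on this new shim (an inference from the call site, not documented behavior): ``shortcut_framework`` presumably forwards imports of this package to the submodule matching the currently configured default framework, e.g.:

# Hypothetical expectation under the default "pytorch" framework setting:
import nni.nas.nn as shortcut_nn          # resolved through shortcut_framework
import nni.nas.nn.pytorch as pytorch_nn

assert shortcut_nn.LayerChoice is pytorch_nn.LayerChoice  # assumed, not verified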
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasLightningModule, RandomSamplingLightningModule
@@ -14,10 +14,10 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler

-import nni.retiarii.nn.pytorch as nas_nn
+import nni.nas.nn.pytorch as nas_nn
 from nni.common.hpo_utils import ParameterSpec
 from nni.common.serializer import is_traceable
-from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.nas.nn.pytorch.choice import ValueChoiceX
 from nni.typehint import Literal

 from .supermodule.base import BaseSuperNetModule
@@ -122,7 +122,7 @@ def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_k
         nas_nn.LayerChoice,
         nas_nn.InputChoice,
         nas_nn.Repeat,
-        nas_nn.NasBench101Cell,
+        # nas_nn.NasBench101Cell,  # FIXME: nasbench101 is moved to hub, can't check any more.
         # nas_nn.ValueChoice,  # could be false positive
         # nas_nn.Cell,  # later
         # nas_nn.NasBench201Cell,  # forward = supernet
@@ -156,8 +156,8 @@ class BaseOneShotLightningModule(pl.LightningModule):
         Extra mutation hooks to support customized mutation on primitives other than built-ins.
         Mutation hooks are callables that take a module and return a
-        :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
-        They are invoked in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodules.
+        :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
+        They are invoked in :func:`~nni.nas.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodule.
         For each submodule, the hooks in the list are invoked in turn;
         later hooks can see the results of previous hooks.
         The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,
@@ -177,21 +177,21 @@ class BaseOneShotLightningModule(pl.LightningModule):
         The return value can also be one of three kinds:

-        1. tuple of: :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and boolean,
+        1. a tuple of :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and a boolean,
         2. a boolean,
-        3. :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.
+        3. :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.

         The boolean value, ``suppress``, indicates whether the following hooks should be called.
         When it's true, it suppresses the subsequent hooks, and they will never be invoked.
         If no boolean value is specified, it's assumed to be false.
         If ``None`` appears in place of
-        :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
+        :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
         it means the hook suggests
         keeping the module unchanged, and nothing will happen.

-        An example of a mutation hook is given in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.no_default_hook`.
+        An example of a mutation hook is given in :func:`~nni.nas.oneshot.pytorch.base_lightning.no_default_hook`.
         However, it's recommended to implement mutation hooks by deriving from
-        :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
+        :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
         and adding its classmethod ``mutate`` to this list.
     """
@@ -309,7 +309,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         Combine architecture optimizers and the user's model optimizers.
         You can overwrite :meth:`configure_architecture_optimizers` if architecture optimizers are needed in your NAS algorithm.
-        For now :attr:`model` is tested against evaluators in :mod:`nni.retiarii.evaluator.pytorch.lightning`
+        For now :attr:`model` is tested against evaluators in :mod:`nni.nas.evaluator.pytorch.lightning`
         and it only returns 1 optimizer.
         But for extensibility, code for other return value types is also implemented.
         """
...
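As a rough illustration of the override mentioned in this docstring — a sketch only, assuming ``configure_architecture_optimizers`` takes no arguments and returns a single torch optimizer, and that architecture parameters can be identified by a naming convention (the ``'alpha'`` filter below is hypothetical):

import torch

from nni.nas.oneshot.pytorch.base_lightning import BaseOneShotLightningModule


class MyOneShotModule(BaseOneShotLightningModule):
    def configure_architecture_optimizers(self):
        # Gather architecture parameters (the naming convention is hypothetical).
        arch_params = [p for name, p in self.named_parameters() if 'alpha' in name]
        return torch.optim.Adam(arch_params, lr=3e-4)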
@@ -4,8 +4,8 @@
 """Strategy integration of one-shot.

 This file is put here simply because it relies on "pytorch".
-For consistency, please consider importing strategies from ``nni.retiarii.strategy``.
-For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires pytorch to be installed of course).
+For consistency, please consider importing strategies from ``nni.nas.strategy``.
+For example, ``nni.nas.strategy.DartsStrategy`` (this requires pytorch to be installed of course).

 When adding/modifying a new strategy in this file, don't forget to link it in strategy/oneshot.py.
 """
@@ -17,9 +17,9 @@ from typing import Any, Type
 import torch.nn as nn

-from nni.retiarii.graph import Model
-from nni.retiarii.strategy.base import BaseStrategy
-from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule
+from nni.nas.execution.common import Model
+from nni.nas.strategy.base import BaseStrategy
+from nni.nas.evaluator.pytorch.lightning import Lightning, LightningModule

 from .base_lightning import BaseOneShotLightningModule
 from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
...
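As a usage sketch of the recommended import path above (note that the strategy package ``__init__`` shown later in this diff exports ``DARTS`` rather than ``DartsStrategy``; constructor arguments are omitted on the assumption that defaults exist):

from nni.nas.strategy import DARTS

# Minimal sketch: instantiate a one-shot strategy through the public package path.
strategy = DARTS()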
@@ -14,7 +14,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from nni.retiarii.nn.pytorch import ValueChoice
+from nni.nas.nn.pytorch import ValueChoice

 class DifferentiableSuperConv2d(nn.Conv2d):
...
@@ -13,7 +13,7 @@ import numpy as np
 import torch

 from nni.common.hpo_utils import ParameterSpec
-from nni.retiarii.nn.pytorch.api import ChoiceOf, ValueChoiceX
+from nni.nas.nn.pytorch.choice import ChoiceOf, ValueChoiceX

 Choice = Any
...
@@ -14,9 +14,9 @@ import torch.nn as nn
 import torch.nn.functional as F

 from nni.common.hpo_utils import ParameterSpec
-from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
-from nni.retiarii.nn.pytorch.api import ValueChoiceX
-from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs
+from nni.nas.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
+from nni.nas.nn.pytorch.choice import ValueChoiceX
+from nni.nas.nn.pytorch.cell import preprocess_cell_inputs

 from .base import BaseSuperNetModule
 from .operation import MixedOperation, MixedOperationSamplingPolicy
@@ -28,7 +28,7 @@ _logger = logging.getLogger(__name__)
 __all__ = [
     'DifferentiableMixedLayer', 'DifferentiableMixedInput',
     'DifferentiableMixedRepeat', 'DifferentiableMixedCell',
-    'MixedOpDifferentiablePolicy',
+    'MixedOpDifferentiablePolicy', 'GumbelSoftmax',
 ]
...
@@ -18,10 +18,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor

-import nni.retiarii.nn.pytorch as retiarii_nn
+import nni.nas.nn.pytorch as nas_nn
 from nni.common.hpo_utils import ParameterSpec
 from nni.common.serializer import is_traceable
-from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.nas.nn.pytorch.choice import ValueChoiceX

 from .base import BaseSuperNetModule
 from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant
@@ -236,7 +236,7 @@ class MixedLinear(MixedOperation, nn.Linear):
     Prefix of weight and bias will be sliced.
     """

-    bound_type = retiarii_nn.Linear
+    bound_type = nas_nn.Linear
     argument_list = ['in_features', 'out_features']

     def super_init_argument(self, name: str, value_choice: ValueChoiceX):
@@ -294,7 +294,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
     □ □ □ □ □
     """

-    bound_type = retiarii_nn.Conv2d
+    bound_type = nas_nn.Conv2d
     argument_list = [
         'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'dilation', 'groups'
     ]
@@ -427,7 +427,7 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
     PyTorch BatchNorm supports a case where momentum can be none, which is not supported here.
     """

-    bound_type = retiarii_nn.BatchNorm2d
+    bound_type = nas_nn.BatchNorm2d
     argument_list = ['num_features', 'eps', 'momentum']

     def super_init_argument(self, name: str, value_choice: ValueChoiceX):
@@ -488,7 +488,7 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
     eps is required to be float.
     """

-    bound_type = retiarii_nn.LayerNorm
+    bound_type = nas_nn.LayerNorm
     argument_list = ['normalized_shape', 'eps']

     @staticmethod
@@ -565,7 +565,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
     All candidates of ``embed_dim`` should be divisible by all candidates of ``num_heads``.
     """

-    bound_type = retiarii_nn.MultiheadAttention
+    bound_type = nas_nn.MultiheadAttention
     argument_list = ['embed_dim', 'num_heads', 'kdim', 'vdim', 'dropout']

     def __post_init__(self):
...
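The recurring ``bound_type`` / ``argument_list`` pattern in this hunk suggests how a new mixed operation is declared. Below is a hypothetical sketch — ``MixedDropout`` is invented for illustration, and the real ``MixedOperation`` contract may require further overrides (e.g. for how the sampled value is applied in ``forward``):

import torch.nn as nn

import nni.nas.nn.pytorch as nas_nn
from nni.nas.oneshot.pytorch.supermodule.operation import MixedOperation
from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options


class MixedDropout(MixedOperation, nn.Dropout):
    """Hypothetical mixed op that lets ``p`` be a value choice."""

    bound_type = nas_nn.Dropout      # the traced module class this op replaces
    argument_list = ['p']            # __init__ arguments that may carry value choices

    def super_init_argument(self, name: str, value_choice):
        # Initialize the underlying nn.Dropout with the largest candidate rate.
        return max(traverse_all_options(value_choice))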
@@ -11,9 +11,9 @@ import torch
 import torch.nn as nn

 from nni.common.hpo_utils import ParameterSpec
-from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
-from nni.retiarii.nn.pytorch.api import ValueChoiceX
-from nni.retiarii.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
+from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
+from nni.nas.nn.pytorch.choice import ValueChoiceX
+from nni.nas.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs

 from .base import BaseSuperNetModule
 from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base import BaseStrategy
from .bruteforce import Random, GridSearch
from .evolution import RegularizedEvolution
from .hpo import TPEStrategy, TPE
from .rl import PolicyBasedRL
from .oneshot import DARTS, Proxyless, GumbelDARTS, ENAS, RandomOneShot
@@ -21,9 +21,9 @@ from tianshou.env.worker import EnvWorker
 from typing_extensions import TypedDict

+from nni.nas.execution import submit_models, wait_models
+from nni.nas.execution.common import ModelStatus
 from .utils import get_targeted_model
-from ..graph import ModelStatus
-from ..execution import submit_models, wait_models

 _logger = logging.getLogger(__name__)
...
@@ -4,8 +4,8 @@
 import abc
 from typing import List, Any

-from ..graph import Model
-from ..mutator import Mutator
+from nni.nas.execution.common import Model
+from nni.nas.mutable import Mutator

 class BaseStrategy(abc.ABC):
...
@@ -8,7 +8,8 @@ import random
 import time
 from typing import Any, Dict, List, Sequence, Optional

-from .. import InvalidMutation, Sampler, submit_models, query_available_resources, budget_exhausted
+from nni.nas.execution import submit_models, query_available_resources, budget_exhausted
+from nni.nas.mutable import InvalidMutation, Sampler

 from .base import BaseStrategy
 from .utils import dry_run_for_search_space, get_targeted_model, filter_model
...
@@ -6,9 +6,10 @@ import os
 import random
 import string

-from .. import Sampler, codegen, utils
-from ..execution.base import BaseGraphData
-from ..execution.utils import get_mutation_summary
+from nni.nas import Sampler, utils
+from nni.nas.execution.pytorch import codegen
+from nni.nas.execution.pytorch.graph import BaseGraphData
+from nni.nas.execution.common import get_mutation_summary

 from .base import BaseStrategy

 _logger = logging.getLogger(__name__)
...