Unverified Commit bc6d8796 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
......@@ -7,8 +7,8 @@ import logging
import random
import time
from ..execution import query_available_resources, submit_models
from ..graph import ModelStatus
from nni.nas.execution import query_available_resources, submit_models
from nni.nas.execution.common import ModelStatus
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Wrappers of HPO tuners as NAS strategy."""
import logging
import time
from typing import Optional
from .. import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
from nni.nas import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
from .base import BaseStrategy
_logger = logging.getLogger(__name__)
......
......@@ -4,7 +4,7 @@
from .base import BaseStrategy
try:
from nni.retiarii.oneshot.pytorch.strategy import ( # pylint: disable=unused-import
from nni.nas.oneshot.pytorch.strategy import ( # pylint: disable=unused-import
DARTS, GumbelDARTS, Proxyless, ENAS, RandomOneShot
)
except ImportError as import_err:
......
......@@ -4,9 +4,9 @@
import logging
from typing import Optional, Callable
from nni.nas.execution import query_available_resources
from .base import BaseStrategy
from .utils import dry_run_for_search_space
from ..execution import query_available_resources
try:
has_tianshou = True
......
......@@ -4,8 +4,8 @@
import collections
import logging
from typing import Dict, Any, List
from ..graph import Model
from ..mutator import Mutator, Sampler
from nni.nas.execution.common import Model
from nni.nas.mutable import Mutator, Sampler
_logger = logging.getLogger(__name__)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .misc import *
from .serializer import *
......@@ -11,7 +11,11 @@ from pathlib import Path
from nni.common.hpo_utils import ParameterSpec
__all__ = ['NoContextError', 'ContextStack', 'ModelNamespace', 'original_state_dict_hooks']
__all__ = [
'NoContextError', 'ContextStack', 'ModelNamespace', 'original_state_dict_hooks',
'uid', 'import_', 'reset_uid', 'get_module_name', 'get_importable_name', 'get_current_context',
'STATE_DICT_PY_MAPPING', 'STATE_DICT_PY_MAPPING_PARTIAL',
]
def import_(target: str, allow_none: bool = False) -> Any:
......
......@@ -7,7 +7,7 @@ import warnings
from typing import Any, TypeVar, Type
from nni.common.serializer import is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from .utils import ModelNamespace
from .misc import ModelNamespace
__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
'is_basic_unit', 'is_model_wrapped']
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.codegen import *
......@@ -107,4 +107,4 @@ class {graph_name}(K.Model):
def call(self, {inputs}):
{edges}
return {outputs}
'''
\ No newline at end of file
'''
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.converter.graph_gen import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.converter.op_types import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.converter.utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.converter.visualize import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.evaluator.functional import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.evaluator.pytorch.cgo.evaluator import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.evaluator.pytorch.cgo.trainer import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.evaluator.pytorch.lightning import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.api import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.graph import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment