Unverified Commit bc6d8796 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
...@@ -7,8 +7,8 @@ import logging ...@@ -7,8 +7,8 @@ import logging
import random import random
import time import time
from ..execution import query_available_resources, submit_models from nni.nas.execution import query_available_resources, submit_models
from ..graph import ModelStatus from nni.nas.execution.common import ModelStatus
from .base import BaseStrategy from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model from .utils import dry_run_for_search_space, get_targeted_model, filter_model
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
"""Wrappers of HPO tuners as NAS strategy."""
import logging import logging
import time import time
from typing import Optional from typing import Optional
from .. import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted from nni.nas import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
from .base import BaseStrategy from .base import BaseStrategy
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from .base import BaseStrategy from .base import BaseStrategy
try: try:
from nni.retiarii.oneshot.pytorch.strategy import ( # pylint: disable=unused-import from nni.nas.oneshot.pytorch.strategy import ( # pylint: disable=unused-import
DARTS, GumbelDARTS, Proxyless, ENAS, RandomOneShot DARTS, GumbelDARTS, Proxyless, ENAS, RandomOneShot
) )
except ImportError as import_err: except ImportError as import_err:
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
import logging import logging
from typing import Optional, Callable from typing import Optional, Callable
from nni.nas.execution import query_available_resources
from .base import BaseStrategy from .base import BaseStrategy
from .utils import dry_run_for_search_space from .utils import dry_run_for_search_space
from ..execution import query_available_resources
try: try:
has_tianshou = True has_tianshou = True
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
import collections import collections
import logging import logging
from typing import Dict, Any, List from typing import Dict, Any, List
from ..graph import Model from nni.nas.execution.common import Model
from ..mutator import Mutator, Sampler from nni.nas.mutable import Mutator, Sampler
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .misc import *
from .serializer import *
...@@ -11,7 +11,11 @@ from pathlib import Path ...@@ -11,7 +11,11 @@ from pathlib import Path
from nni.common.hpo_utils import ParameterSpec from nni.common.hpo_utils import ParameterSpec
__all__ = ['NoContextError', 'ContextStack', 'ModelNamespace', 'original_state_dict_hooks'] __all__ = [
'NoContextError', 'ContextStack', 'ModelNamespace', 'original_state_dict_hooks',
'uid', 'import_', 'reset_uid', 'get_module_name', 'get_importable_name', 'get_current_context',
'STATE_DICT_PY_MAPPING', 'STATE_DICT_PY_MAPPING_PARTIAL',
]
def import_(target: str, allow_none: bool = False) -> Any: def import_(target: str, allow_none: bool = False) -> Any:
......
...@@ -7,7 +7,7 @@ import warnings ...@@ -7,7 +7,7 @@ import warnings
from typing import Any, TypeVar, Type from typing import Any, TypeVar, Type
from nni.common.serializer import is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes from nni.common.serializer import is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from .utils import ModelNamespace from .misc import ModelNamespace
__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper', __all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
'is_basic_unit', 'is_model_wrapped'] 'is_basic_unit', 'is_model_wrapped']
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.codegen import *
...@@ -107,4 +107,4 @@ class {graph_name}(K.Model): ...@@ -107,4 +107,4 @@ class {graph_name}(K.Model):
def call(self, {inputs}): def call(self, {inputs}):
{edges} {edges}
return {outputs} return {outputs}
''' '''
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.converter.graph_gen import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.converter.op_types import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.converter.utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.converter.visualize import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.evaluator.functional import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.evaluator.pytorch.cgo.evaluator import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.evaluator.pytorch.cgo.trainer import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.evaluator.pytorch.lightning import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.api import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.graph import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment