Unverified Commit bc6d8796 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.sampling import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.strategy import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule._operation_utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule._singlepathnas import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule.base import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule.differentiable import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule.operation import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule.proxyless import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule.sampling import *
...@@ -12,7 +12,6 @@ import torch ...@@ -12,7 +12,6 @@ import torch
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.nas.pytorch.mutables import InputChoice, LayerChoice
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -163,7 +162,7 @@ def replace_layer_choice(root_module, init_fn, modules=None): ...@@ -163,7 +162,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
list[tuple[str, nn.Module]] list[tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules. A list from layer choice keys (names) and replaced modules.
""" """
return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules) return _replace_module_with_type(root_module, init_fn, nn.LayerChoice, modules)
def replace_input_choice(root_module, init_fn, modules=None): def replace_input_choice(root_module, init_fn, modules=None):
...@@ -184,7 +183,7 @@ def replace_input_choice(root_module, init_fn, modules=None): ...@@ -184,7 +183,7 @@ def replace_input_choice(root_module, init_fn, modules=None):
list[tuple[str, nn.Module]] list[tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules. A list from layer choice keys (names) and replaced modules.
""" """
return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules) return _replace_module_with_type(root_module, init_fn, nn.InputChoice, modules)
class InterleavedTrainValDataLoader(DataLoader): class InterleavedTrainValDataLoader(DataLoader):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.common.graph_op import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.tensorflow.op_def import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.op_def import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.utils.serializer import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.base import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.bruteforce import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.evolution import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.debug import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.oneshot import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment