"test/vscode:/vscode.git/clone" did not exist on "2a2c146c3fd967d4ff7f9d53d208e8a55bcf5729"
Unverified Commit bc6d8796 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.sampling import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.strategy import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule._operation_utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule._singlepathnas import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule.base import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule.differentiable import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule.operation import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule.proxyless import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.oneshot.pytorch.supermodule.sampling import *
......@@ -12,7 +12,6 @@ import torch
from torch.utils.data import DataLoader, Dataset
import nni.retiarii.nn.pytorch as nn
from nni.nas.pytorch.mutables import InputChoice, LayerChoice
_logger = logging.getLogger(__name__)
......@@ -163,7 +162,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
list[tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules.
"""
return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules)
return _replace_module_with_type(root_module, init_fn, nn.LayerChoice, modules)
def replace_input_choice(root_module, init_fn, modules=None):
......@@ -184,7 +183,7 @@ def replace_input_choice(root_module, init_fn, modules=None):
list[tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules.
"""
return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
return _replace_module_with_type(root_module, init_fn, nn.InputChoice, modules)
class InterleavedTrainValDataLoader(DataLoader):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.common.graph_op import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.tensorflow.op_def import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.pytorch.op_def import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.utils.serializer import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.base import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.bruteforce import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.evolution import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.debug import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.oneshot import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment