Unverified Commit bc6d8796 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.rl import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.hpo import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.trial_entry import main
if __name__ == '__main__':
main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.utils.misc import *
......@@ -9,8 +9,8 @@
"nni/common/device.py",
"nni/common/graph_utils.py",
"nni/compression",
"nni/nas/tensorflow",
"nni/nas/pytorch",
"nni/nas/execution/pytorch/cgo",
"nni/nas/evaluator/pytorch/cgo",
"nni/retiarii/execution/cgo_engine.py",
"nni/retiarii/execution/logical_optimizer",
"nni/retiarii/evaluator/pytorch/cgo",
......
......@@ -32,6 +32,8 @@ try:
from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer
import nni.retiarii.integration_api
module_import_failed = False
except ImportError:
module_import_failed = True
......
......@@ -14,7 +14,7 @@ import nni.runtime.platform.test
import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.hub.pytorch as searchspace
from nni.retiarii import fixed_arch
from nni.retiarii.execution.utils import _unpack_if_only_one
from nni.retiarii.execution.utils import unpack_if_only_one
from nni.retiarii.mutator import InvalidMutation, Sampler
from nni.retiarii.nn.pytorch.mutator import extract_mutation_from_pt_module
......@@ -58,7 +58,7 @@ def _test_searchspace_on_dataset(searchspace, dataset='cifar10', arch=None):
if arch is None:
model = try_mutation_until_success(model, mutators, 10)
arch = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
arch = {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model.history}
print('Selected model:', arch)
with fixed_arch(arch):
......
......@@ -56,7 +56,10 @@ class MockExecutionEngine(AbstractExecutionEngine):
def _reset_execution_engine(engine=None):
nni.retiarii.execution.api._execution_engine = engine
# Use the new NAS reset
# nni.retiarii.execution.api._execution_engine = engine
import nni.nas.execution.api
nni.nas.execution.api._execution_engine = engine
class Net(nn.Module):
......
......@@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import nni.retiarii.nn.pytorch
import nni.nas.nn.pytorch
import torch
......
......@@ -4,6 +4,7 @@ import unittest
from pathlib import Path
import nni.retiarii
import nni.retiarii.integration_api
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution import set_execution_engine
......
import json
from pathlib import Path
import sys
from nni.common.framework import get_default_framework, set_default_framework
from nni.retiarii import *
# FIXME
import nni.retiarii.debug_configs
original_framework = nni.retiarii.debug_configs.framework
original_framework = get_default_framework()
max_pool = Operation.new('MaxPool2D', {'pool_size': 2})
avg_pool = Operation.new('AveragePooling2D', {'pool_size': 2})
......@@ -14,11 +12,11 @@ global_pool = Operation.new('GlobalAveragePooling2D')
def setup_module(module):
nni.retiarii.debug_configs.framework = 'tensorflow'
set_default_framework('tensorflow')
def teardown_module(module):
nni.retiarii.debug_configs.framework = original_framework
set_default_framework(original_framework)
class DebugSampler(Sampler):
......
......@@ -15,7 +15,7 @@ from nni.retiarii import InvalidMutation, Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.evaluator import FunctionalEvaluator
from nni.retiarii.execution.utils import _unpack_if_only_one
from nni.retiarii.execution.utils import unpack_if_only_one
from nni.retiarii.experiment.pytorch import preprocess_model
from nni.retiarii.graph import Model
from nni.retiarii.nn.pytorch.api import ValueChoice
......@@ -827,7 +827,7 @@ class Python(GraphIR):
graph_engine = False
def _get_converted_pytorch_model(self, model_ir):
mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history}
mutation = {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model_ir.history}
with ContextStack('fixed', mutation):
model = model_ir.python_class(**model_ir.python_init_params)
return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment