Unverified Commit bc6d8796 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.rl import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.hpo import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.strategy.utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.execution.trial_entry import main
if __name__ == '__main__':
main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.nas.utils.misc import *
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
"nni/common/device.py", "nni/common/device.py",
"nni/common/graph_utils.py", "nni/common/graph_utils.py",
"nni/compression", "nni/compression",
"nni/nas/tensorflow", "nni/nas/execution/pytorch/cgo",
"nni/nas/pytorch", "nni/nas/evaluator/pytorch/cgo",
"nni/retiarii/execution/cgo_engine.py", "nni/retiarii/execution/cgo_engine.py",
"nni/retiarii/execution/logical_optimizer", "nni/retiarii/execution/logical_optimizer",
"nni/retiarii/evaluator/pytorch/cgo", "nni/retiarii/evaluator/pytorch/cgo",
......
...@@ -32,6 +32,8 @@ try: ...@@ -32,6 +32,8 @@ try:
from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer
import nni.retiarii.integration_api
module_import_failed = False module_import_failed = False
except ImportError: except ImportError:
module_import_failed = True module_import_failed = True
......
...@@ -14,7 +14,7 @@ import nni.runtime.platform.test ...@@ -14,7 +14,7 @@ import nni.runtime.platform.test
import nni.retiarii.evaluator.pytorch.lightning as pl import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.hub.pytorch as searchspace import nni.retiarii.hub.pytorch as searchspace
from nni.retiarii import fixed_arch from nni.retiarii import fixed_arch
from nni.retiarii.execution.utils import _unpack_if_only_one from nni.retiarii.execution.utils import unpack_if_only_one
from nni.retiarii.mutator import InvalidMutation, Sampler from nni.retiarii.mutator import InvalidMutation, Sampler
from nni.retiarii.nn.pytorch.mutator import extract_mutation_from_pt_module from nni.retiarii.nn.pytorch.mutator import extract_mutation_from_pt_module
...@@ -58,7 +58,7 @@ def _test_searchspace_on_dataset(searchspace, dataset='cifar10', arch=None): ...@@ -58,7 +58,7 @@ def _test_searchspace_on_dataset(searchspace, dataset='cifar10', arch=None):
if arch is None: if arch is None:
model = try_mutation_until_success(model, mutators, 10) model = try_mutation_until_success(model, mutators, 10)
arch = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history} arch = {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model.history}
print('Selected model:', arch) print('Selected model:', arch)
with fixed_arch(arch): with fixed_arch(arch):
......
...@@ -56,7 +56,10 @@ class MockExecutionEngine(AbstractExecutionEngine): ...@@ -56,7 +56,10 @@ class MockExecutionEngine(AbstractExecutionEngine):
def _reset_execution_engine(engine=None): def _reset_execution_engine(engine=None):
nni.retiarii.execution.api._execution_engine = engine # Use the new NAS reset
# nni.retiarii.execution.api._execution_engine = engine
import nni.nas.execution.api
nni.nas.execution.api._execution_engine = engine
class Net(nn.Module): class Net(nn.Module):
......
...@@ -3,7 +3,7 @@ import torch.nn as nn ...@@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import nni.retiarii.nn.pytorch import nni.nas.nn.pytorch
import torch import torch
......
...@@ -4,6 +4,7 @@ import unittest ...@@ -4,6 +4,7 @@ import unittest
from pathlib import Path from pathlib import Path
import nni.retiarii import nni.retiarii
import nni.retiarii.integration_api
from nni.retiarii import Model, submit_models from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution import set_execution_engine from nni.retiarii.execution import set_execution_engine
......
import json import json
from pathlib import Path from pathlib import Path
import sys
from nni.common.framework import get_default_framework, set_default_framework
from nni.retiarii import * from nni.retiarii import *
# FIXME original_framework = get_default_framework()
import nni.retiarii.debug_configs
original_framework = nni.retiarii.debug_configs.framework
max_pool = Operation.new('MaxPool2D', {'pool_size': 2}) max_pool = Operation.new('MaxPool2D', {'pool_size': 2})
avg_pool = Operation.new('AveragePooling2D', {'pool_size': 2}) avg_pool = Operation.new('AveragePooling2D', {'pool_size': 2})
...@@ -14,11 +12,11 @@ global_pool = Operation.new('GlobalAveragePooling2D') ...@@ -14,11 +12,11 @@ global_pool = Operation.new('GlobalAveragePooling2D')
def setup_module(module): def setup_module(module):
nni.retiarii.debug_configs.framework = 'tensorflow' set_default_framework('tensorflow')
def teardown_module(module): def teardown_module(module):
nni.retiarii.debug_configs.framework = original_framework set_default_framework(original_framework)
class DebugSampler(Sampler): class DebugSampler(Sampler):
......
...@@ -15,7 +15,7 @@ from nni.retiarii import InvalidMutation, Sampler, basic_unit ...@@ -15,7 +15,7 @@ from nni.retiarii import InvalidMutation, Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.evaluator import FunctionalEvaluator from nni.retiarii.evaluator import FunctionalEvaluator
from nni.retiarii.execution.utils import _unpack_if_only_one from nni.retiarii.execution.utils import unpack_if_only_one
from nni.retiarii.experiment.pytorch import preprocess_model from nni.retiarii.experiment.pytorch import preprocess_model
from nni.retiarii.graph import Model from nni.retiarii.graph import Model
from nni.retiarii.nn.pytorch.api import ValueChoice from nni.retiarii.nn.pytorch.api import ValueChoice
...@@ -827,7 +827,7 @@ class Python(GraphIR): ...@@ -827,7 +827,7 @@ class Python(GraphIR):
graph_engine = False graph_engine = False
def _get_converted_pytorch_model(self, model_ir): def _get_converted_pytorch_model(self, model_ir):
mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history} mutation = {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model_ir.history}
with ContextStack('fixed', mutation): with ContextStack('fixed', mutation):
model = model_ir.python_class(**model_ir.python_init_params) model = model_ir.python_class(**model_ir.python_init_params)
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment