Unverified commit bc6d8796, authored by Yuge Zhang and committed by GitHub

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
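
The recurring pattern in this commit is mechanical: every relative or `nni.retiarii.*` import is redirected to its new home under the promoted `nni.nas.*` package. As a minimal before/after sketch (import lines taken from the hunks below, assuming an NNI build that already ships the promoted `nni.nas` package from step 1):

# Old Retiarii namespace, removed by this commit:
import nni.retiarii.nn.pytorch
from nni.retiarii.utils import uid

# New NAS namespace, added by this commit:
import nni.nas.nn.pytorch
from nni.nas.utils import uid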
......@@ -5,10 +5,8 @@ import os
import random
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast
from ..graph import Model
from ..integration_api import receive_trial_parameters
from .base import BaseExecutionEngine
from .utils import get_mutation_dict
from nni.nas.execution.common import Model, receive_trial_parameters, get_mutation_dict
from .graph import BaseExecutionEngine
class BenchmarkGraphData:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .engine import *
......@@ -3,6 +3,8 @@
from __future__ import annotations
__all__ = ['CGOExecutionEngine', 'TrialSubmission']
import logging
import os
import random
......@@ -14,17 +16,19 @@ from dataclasses import dataclass
from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from nni.retiarii.integration import RetiariiAdvisor
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Node
from ..integration_api import send_trial, receive_trial_parameters, get_advisor
from nni.nas import utils
from nni.nas.execution.common import (
AbstractExecutionEngine, AbstractGraphListener, WorkerInfo,
Model, ModelStatus, MetricData, Node,
RetiariiAdvisor, send_trial, receive_trial_parameters, get_advisor,
)
from nni.nas.execution.pytorch import codegen
from nni.nas.evaluator.pytorch.lightning import Lightning
from nni.nas.evaluator.pytorch.cgo.evaluator import _MultiModelSupervisedLearningModule
from nni.nas.execution.pytorch.graph import BaseGraphData
from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from ..evaluator.pytorch.lightning import Lightning
from ..evaluator.pytorch.cgo.evaluator import _MultiModelSupervisedLearningModule
from .base import BaseGraphData
_logger = logging.getLogger(__name__)
......
......@@ -7,8 +7,8 @@ from typing import Dict, Tuple, Any
from nni.retiarii.utils import uid
from nni.common.device import Device, CPUDevice
from ...graph import Cell, Edge, Graph, Model, Node
from ...operation import Operation, _IOPseudoOperation
from nni.nas.execution.common.graph import Cell, Edge, Graph, Model, Node
from nni.nas.execution.common.graph_op import Operation, _IOPseudoOperation
class AbstractLogicalNode(Node):
......
......@@ -3,11 +3,11 @@
from typing import List, Dict, Tuple
from nni.retiarii.utils import uid
from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
from nni.nas.utils import uid
from nni.nas.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
from nni.common.device import GPUDevice
from ...graph import Graph, Model, Node
from nni.nas.execution.common.graph import Graph, Model, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
OriginNode)
......
......@@ -7,12 +7,12 @@ import logging
import re
from typing import Dict, List, Tuple, Any, cast
from nni.retiarii.operation import PyTorchOperation
from nni.retiarii.operation_def.torch_op_def import ToDevice
from nni.retiarii.utils import STATE_DICT_PY_MAPPING
from nni.common.device import Device, GPUDevice
from nni.nas.execution.common.graph import IllegalGraphError, Edge, Graph, Node, Model
from nni.nas.execution.common.graph_op import PyTorchOperation
from nni.nas.utils import STATE_DICT_PY_MAPPING
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
from .op_def import ToDevice
_logger = logging.getLogger(__name__)
......@@ -215,7 +215,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import nni.retiarii.nn.pytorch
import nni.nas.nn.pytorch
{}
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .graph_gen import convert_to_graph
......@@ -5,11 +5,9 @@ import re
import torch
from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, Placeholder, LayerChoice
from ..operation import Cell, Operation
from ..serializer import get_init_parameters_or_fail
from ..utils import get_importable_name
from nni.nas.execution.common import Graph, Model, Node, Cell, Operation
from nni.nas.nn.pytorch import InputChoice, Placeholder, LayerChoice
from nni.nas.utils import get_init_parameters_or_fail, get_importable_name
from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import (
_convert_name, build_full_name, _without_shape_info,
......
......@@ -5,8 +5,7 @@ from typing import Optional
from typing_extensions import TypeGuard
from ..operation import Cell
from ..graph import Model, Graph, Node, Edge
from nni.nas.execution.common import Cell, Model, Graph, Node, Edge
def build_full_name(prefix, name, seq=None):
......
......@@ -3,6 +3,8 @@
from __future__ import annotations
__all__ = ['BaseGraphData', 'BaseExecutionEngine']
import logging
import os
import random
......@@ -10,13 +12,14 @@ import string
from typing import Any, Dict, Iterable, List
from nni.experiment import rest
from nni.retiarii.integration import RetiariiAdvisor
from .interface import AbstractExecutionEngine, AbstractGraphListener
from .utils import get_mutation_summary
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Evaluator
from ..integration_api import send_trial, receive_trial_parameters, get_advisor
from nni.nas.execution.common import (
AbstractExecutionEngine, AbstractGraphListener, RetiariiAdvisor, get_mutation_summary,
Model, ModelStatus, MetricData, Evaluator,
send_trial, receive_trial_parameters, get_advisor
)
from nni.nas.utils import import_
from .codegen import model_to_pytorch_script
_logger = logging.getLogger(__name__)
......@@ -146,7 +149,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def pack_model_data(cls, model: Model) -> Any:
mutation_summary = get_mutation_summary(model)
assert model.evaluator is not None, 'Model evaluator can not be None'
return BaseGraphData(codegen.pytorch.model_to_pytorch_script(model), model.evaluator, mutation_summary) # type: ignore
return BaseGraphData(model_to_pytorch_script(model), model.evaluator, mutation_summary) # type: ignore
@classmethod
def trial_execute_graph(cls) -> None:
......@@ -159,6 +162,6 @@ class BaseExecutionEngine(AbstractExecutionEngine):
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
model_cls = utils.import_(f'_generated_model.{random_str}._model')
model_cls = import_(f'_generated_model.{random_str}._model')
graph_data.evaluator._execute(model_cls)
os.remove(file_name)
......@@ -8,7 +8,7 @@ from typing import (Any, Dict, List)
import torch
import torch.nn.functional as nn_functional
from ..operation import PyTorchOperation
from nni.nas.execution.common import PyTorchOperation
mem_format = [
......
......@@ -5,11 +5,13 @@ from typing import Dict, Any, Type, cast
import torch.nn as nn
from ..graph import Evaluator, Model
from ..integration_api import receive_trial_parameters
from ..utils import ContextStack
from .base import BaseExecutionEngine
from .utils import get_mutation_dict, mutation_dict_to_summary
from nni.nas.execution.common import (
Model, receive_trial_parameters,
get_mutation_dict, mutation_dict_to_summary
)
from nni.nas.evaluator import Evaluator
from nni.nas.utils import ContextStack
from .graph import BaseExecutionEngine
class PythonGraphData:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..operation import TensorFlowOperation
from nni.nas.execution.common import TensorFlowOperation
class Conv2D(TensorFlowOperation):
......
......@@ -3,28 +3,31 @@
"""
Entrypoint for trials.
Assumes the execution engine is BaseExecutionEngine.
"""
import argparse
if __name__ == '__main__':
def main():
parser = argparse.ArgumentParser()
parser.add_argument('exec', choices=['base', 'py', 'cgo', 'benchmark'])
args = parser.parse_args()
if args.exec == 'base':
from .execution.base import BaseExecutionEngine
from .pytorch.graph import BaseExecutionEngine
engine = BaseExecutionEngine
elif args.exec == 'cgo':
from .execution.cgo_engine import CGOExecutionEngine
from .pytorch.cgo import CGOExecutionEngine
engine = CGOExecutionEngine
elif args.exec == 'py':
from .execution.python import PurePythonExecutionEngine
from .pytorch.simplified import PurePythonExecutionEngine
engine = PurePythonExecutionEngine
elif args.exec == 'benchmark':
from .execution.benchmark import BenchmarkExecutionEngine
from .pytorch.benchmark import BenchmarkExecutionEngine
engine = BenchmarkExecutionEngine
else:
raise ValueError(f'Unrecognized execution engine: {args.exec}')
engine.trial_execute_graph()
if __name__ == '__main__':
main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .experiment_config import *
from .engine_config import *
\ No newline at end of file
......@@ -3,11 +3,13 @@
from __future__ import annotations
__all__ = ['RetiariiExeConfig', 'RetiariiExperiment', 'preprocess_model', 'debug_mutated_model']
import logging
import warnings
from threading import Thread
from typing import Any, List, Union, cast
from typing import Any, List, cast
import colorama
......@@ -16,32 +18,27 @@ import torch.nn as nn
from nni.experiment import Experiment, RunMode
from nni.experiment.config.training_services import RemoteConfig
from nni.nas.execution import list_models, set_execution_engine
from nni.nas.execution.common import RetiariiAdvisor, get_mutation_dict
from nni.nas.execution.pytorch.codegen import model_to_pytorch_script
from nni.nas.execution.pytorch.converter import convert_to_graph
from nni.nas.execution.pytorch.converter.graph_gen import GraphConverterWithShape
from nni.nas.evaluator import Evaluator
from nni.nas.mutable import Mutator
from nni.nas.nn.pytorch.mutator import (
extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations, process_oneshot_mutations
)
from nni.nas.utils import is_model_wrapped
from nni.nas.strategy import BaseStrategy
from nni.nas.strategy.utils import dry_run_for_formatted_search_space
from .config import (
RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig,
PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
)
from ..codegen.pytorch import model_to_pytorch_script
from ..converter import convert_to_graph
from ..converter.graph_gen import GraphConverterWithShape
from ..execution import list_models, set_execution_engine
from ..execution.utils import get_mutation_dict
from ..graph import Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import (
extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations, process_oneshot_mutations
)
from ..oneshot.interface import BaseOneShotTrainer
from ..serializer import is_model_wrapped
from ..strategy import BaseStrategy
from ..strategy.utils import dry_run_for_formatted_search_space
_logger = logging.getLogger(__name__)
__all__ = ['RetiariiExperiment']
def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):
# TODO: this logic might need to be refactored into execution engine
if oneshot:
......@@ -97,7 +94,7 @@ def debug_mutated_model(base_model, evaluator, applied_mutators):
a list of mutators that will be applied on the base model for generating a new model
"""
base_model_ir, applied_mutators = preprocess_model(base_model, evaluator, applied_mutators)
from ..strategy.local_debug_strategy import _LocalDebugStrategy
from nni.nas.strategy.debug import _LocalDebugStrategy
strategy = _LocalDebugStrategy()
strategy.run(base_model_ir, applied_mutators)
_logger.info('local debug completed!')
......@@ -174,10 +171,10 @@ class RetiariiExperiment(Experiment):
"""
def __init__(self, base_model: nn.Module,
evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
evaluator: Evaluator = cast(Evaluator, None),
applied_mutators: List[Mutator] = cast(List[Mutator], None),
strategy: BaseStrategy = cast(BaseStrategy, None),
trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
trainer: Any = None):
super().__init__(None)
self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
......@@ -190,7 +187,7 @@ class RetiariiExperiment(Experiment):
raise ValueError('Evaluator should not be none.')
self.base_model = base_model
self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
self.evaluator: Evaluator = evaluator
self.applied_mutators = applied_mutators
self.strategy = strategy
......@@ -222,10 +219,10 @@ class RetiariiExperiment(Experiment):
def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
# TODO: we will probably need an execution engine factory to make this clean and elegant
if isinstance(config.execution_engine, BaseEngineConfig):
from ..execution.base import BaseExecutionEngine
from nni.nas.execution.pytorch.graph import BaseExecutionEngine
engine = BaseExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, CgoEngineConfig):
from ..execution.cgo_engine import CGOExecutionEngine
from nni.nas.execution.pytorch.cgo import CGOExecutionEngine
assert not isinstance(config.training_service, list) \
and config.training_service.platform == 'remote', \
......@@ -238,10 +235,10 @@ class RetiariiExperiment(Experiment):
rest_port=self.port,
rest_url_prefix=self.url_prefix)
elif isinstance(config.execution_engine, PyEngineConfig):
from ..execution.python import PurePythonExecutionEngine
from nni.nas.execution.pytorch.simplified import PurePythonExecutionEngine
engine = PurePythonExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, BenchmarkEngineConfig):
from ..execution.benchmark import BenchmarkExecutionEngine
from nni.nas.execution.pytorch.benchmark import BenchmarkExecutionEngine
assert config.execution_engine.benchmark is not None, \
'"benchmark" must be set when benchmark execution engine is used.'
engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
......@@ -265,12 +262,13 @@ class RetiariiExperiment(Experiment):
Run the experiment.
This function will block until experiment finish or error.
"""
from nni.retiarii.oneshot.interface import BaseOneShotTrainer
if isinstance(self.evaluator, BaseOneShotTrainer):
# TODO: will throw a deprecation warning soon
# warnings.warn('You are using the old implementation of one-shot algos based on One-shot trainer. '
# 'We will try to convert this trainer to our new implementation to run the algorithm. '
# 'In case you want to stick to the old implementation, '
# 'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
warnings.warn('You are using the old implementation of one-shot algos based on One-shot trainer. '
'We will try to convert this trainer to our new implementation to run the algorithm. '
'In case you want to stick to the old implementation, '
'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
self.evaluator.fit()
return
......@@ -344,6 +342,7 @@ class RetiariiExperiment(Experiment):
config = self.config.canonical_copy()
assert not isinstance(config.execution_engine, PyEngineConfig), \
'You should use `dict` formatter when using Python execution engine.'
from nni.retiarii.oneshot.interface import BaseOneShotTrainer
if isinstance(self.evaluator, BaseOneShotTrainer):
assert top_k == 1, 'Only support top_k is 1 for now.'
return self.evaluator.export()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.