"test/srt/git@developer.sourcefind.cn:change/sglang.git" did not exist on "ce86979355d4e96f4ad610a8f100edc930743359"
Unverified commit bc6d8796, authored by Yuge Zhang, committed by GitHub
Browse files

Promote Retiarii to NAS (step 2) - update imports (#5025)

parent 867871b2
...@@ -5,10 +5,8 @@ import os ...@@ -5,10 +5,8 @@ import os
import random import random
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast
from ..graph import Model from nni.nas.execution.common import Model, receive_trial_parameters, get_mutation_dict
from ..integration_api import receive_trial_parameters from .graph import BaseExecutionEngine
from .base import BaseExecutionEngine
from .utils import get_mutation_dict
class BenchmarkGraphData: class BenchmarkGraphData:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .engine import *
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
from __future__ import annotations from __future__ import annotations
__all__ = ['CGOExecutionEngine', 'TrialSubmission']
import logging import logging
import os import os
import random import random
...@@ -14,17 +16,19 @@ from dataclasses import dataclass ...@@ -14,17 +16,19 @@ from dataclasses import dataclass
from nni.common.device import GPUDevice, Device from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig from nni.experiment.config.training_services import RemoteConfig
from nni.retiarii.integration import RetiariiAdvisor from nni.nas import utils
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo from nni.nas.execution.common import (
from .. import codegen, utils AbstractExecutionEngine, AbstractGraphListener, WorkerInfo,
from ..graph import Model, ModelStatus, MetricData, Node Model, ModelStatus, MetricData, Node,
from ..integration_api import send_trial, receive_trial_parameters, get_advisor RetiariiAdvisor, send_trial, receive_trial_parameters, get_advisor,
)
from nni.nas.execution.pytorch import codegen
from nni.nas.evaluator.pytorch.lightning import Lightning
from nni.nas.evaluator.pytorch.cgo.evaluator import _MultiModelSupervisedLearningModule
from nni.nas.execution.pytorch.graph import BaseGraphData
from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode from .logical_optimizer.logical_plan import LogicalPlan, AbstractLogicalNode
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from ..evaluator.pytorch.lightning import Lightning
from ..evaluator.pytorch.cgo.evaluator import _MultiModelSupervisedLearningModule
from .base import BaseGraphData
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
......
...@@ -7,8 +7,8 @@ from typing import Dict, Tuple, Any ...@@ -7,8 +7,8 @@ from typing import Dict, Tuple, Any
from nni.retiarii.utils import uid from nni.retiarii.utils import uid
from nni.common.device import Device, CPUDevice from nni.common.device import Device, CPUDevice
from ...graph import Cell, Edge, Graph, Model, Node from nni.nas.execution.common.graph import Cell, Edge, Graph, Model, Node
from ...operation import Operation, _IOPseudoOperation from nni.nas.execution.common.graph_op import Operation, _IOPseudoOperation
class AbstractLogicalNode(Node): class AbstractLogicalNode(Node):
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
from nni.retiarii.utils import uid from nni.nas.utils import uid
from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule from nni.nas.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule
from nni.common.device import GPUDevice from nni.common.device import GPUDevice
from ...graph import Graph, Model, Node from nni.nas.execution.common.graph import Graph, Model, Node
from .interface import AbstractOptimizer from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan, from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
OriginNode) OriginNode)
......
...@@ -7,12 +7,12 @@ import logging ...@@ -7,12 +7,12 @@ import logging
import re import re
from typing import Dict, List, Tuple, Any, cast from typing import Dict, List, Tuple, Any, cast
from nni.retiarii.operation import PyTorchOperation
from nni.retiarii.operation_def.torch_op_def import ToDevice
from nni.retiarii.utils import STATE_DICT_PY_MAPPING
from nni.common.device import Device, GPUDevice from nni.common.device import Device, GPUDevice
from nni.nas.execution.common.graph import IllegalGraphError, Edge, Graph, Node, Model
from nni.nas.execution.common.graph_op import PyTorchOperation
from nni.nas.utils import STATE_DICT_PY_MAPPING
from ..graph import IllegalGraphError, Edge, Graph, Node, Model from .op_def import ToDevice
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -215,7 +215,7 @@ import torch.nn as nn ...@@ -215,7 +215,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import nni.retiarii.nn.pytorch import nni.nas.nn.pytorch
{} {}
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .graph_gen import convert_to_graph
...@@ -5,11 +5,9 @@ import re ...@@ -5,11 +5,9 @@ import re
import torch import torch
from ..graph import Graph, Model, Node from nni.nas.execution.common import Graph, Model, Node, Cell, Operation
from ..nn.pytorch import InputChoice, Placeholder, LayerChoice from nni.nas.nn.pytorch import InputChoice, Placeholder, LayerChoice
from ..operation import Cell, Operation from nni.nas.utils import get_init_parameters_or_fail, get_importable_name
from ..serializer import get_init_parameters_or_fail
from ..utils import get_importable_name
from .op_types import MODULE_EXCEPT_LIST, OpTypeName from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import ( from .utils import (
_convert_name, build_full_name, _without_shape_info, _convert_name, build_full_name, _without_shape_info,
......
...@@ -5,8 +5,7 @@ from typing import Optional ...@@ -5,8 +5,7 @@ from typing import Optional
from typing_extensions import TypeGuard from typing_extensions import TypeGuard
from ..operation import Cell from nni.nas.execution.common import Cell, Model, Graph, Node, Edge
from ..graph import Model, Graph, Node, Edge
def build_full_name(prefix, name, seq=None): def build_full_name(prefix, name, seq=None):
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
from __future__ import annotations from __future__ import annotations
__all__ = ['BaseGraphData', 'BaseExecutionEngine']
import logging import logging
import os import os
import random import random
...@@ -10,13 +12,14 @@ import string ...@@ -10,13 +12,14 @@ import string
from typing import Any, Dict, Iterable, List from typing import Any, Dict, Iterable, List
from nni.experiment import rest from nni.experiment import rest
from nni.retiarii.integration import RetiariiAdvisor
from .interface import AbstractExecutionEngine, AbstractGraphListener from nni.nas.execution.common import (
from .utils import get_mutation_summary AbstractExecutionEngine, AbstractGraphListener, RetiariiAdvisor, get_mutation_summary,
from .. import codegen, utils Model, ModelStatus, MetricData, Evaluator,
from ..graph import Model, ModelStatus, MetricData, Evaluator send_trial, receive_trial_parameters, get_advisor
from ..integration_api import send_trial, receive_trial_parameters, get_advisor )
from nni.nas.utils import import_
from .codegen import model_to_pytorch_script
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -146,7 +149,7 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -146,7 +149,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def pack_model_data(cls, model: Model) -> Any: def pack_model_data(cls, model: Model) -> Any:
mutation_summary = get_mutation_summary(model) mutation_summary = get_mutation_summary(model)
assert model.evaluator is not None, 'Model evaluator can not be None' assert model.evaluator is not None, 'Model evaluator can not be None'
return BaseGraphData(codegen.pytorch.model_to_pytorch_script(model), model.evaluator, mutation_summary) # type: ignore return BaseGraphData(model_to_pytorch_script(model), model.evaluator, mutation_summary) # type: ignore
@classmethod @classmethod
def trial_execute_graph(cls) -> None: def trial_execute_graph(cls) -> None:
...@@ -159,6 +162,6 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -159,6 +162,6 @@ class BaseExecutionEngine(AbstractExecutionEngine):
os.makedirs(os.path.dirname(file_name), exist_ok=True) os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f: with open(file_name, 'w') as f:
f.write(graph_data.model_script) f.write(graph_data.model_script)
model_cls = utils.import_(f'_generated_model.{random_str}._model') model_cls = import_(f'_generated_model.{random_str}._model')
graph_data.evaluator._execute(model_cls) graph_data.evaluator._execute(model_cls)
os.remove(file_name) os.remove(file_name)
...@@ -8,7 +8,7 @@ from typing import (Any, Dict, List) ...@@ -8,7 +8,7 @@ from typing import (Any, Dict, List)
import torch import torch
import torch.nn.functional as nn_functional import torch.nn.functional as nn_functional
from ..operation import PyTorchOperation from nni.nas.execution.common import PyTorchOperation
mem_format = [ mem_format = [
......
...@@ -5,11 +5,13 @@ from typing import Dict, Any, Type, cast ...@@ -5,11 +5,13 @@ from typing import Dict, Any, Type, cast
import torch.nn as nn import torch.nn as nn
from ..graph import Evaluator, Model from nni.nas.execution.common import (
from ..integration_api import receive_trial_parameters Model, receive_trial_parameters,
from ..utils import ContextStack get_mutation_dict, mutation_dict_to_summary
from .base import BaseExecutionEngine )
from .utils import get_mutation_dict, mutation_dict_to_summary from nni.nas.evaluator import Evaluator
from nni.nas.utils import ContextStack
from .graph import BaseExecutionEngine
class PythonGraphData: class PythonGraphData:
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from ..operation import TensorFlowOperation from nni.nas.execution.common import TensorFlowOperation
class Conv2D(TensorFlowOperation): class Conv2D(TensorFlowOperation):
......
...@@ -3,28 +3,31 @@ ...@@ -3,28 +3,31 @@
""" """
Entrypoint for trials. Entrypoint for trials.
Assuming execution engine is BaseExecutionEngine.
""" """
import argparse import argparse
if __name__ == '__main__': def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('exec', choices=['base', 'py', 'cgo', 'benchmark']) parser.add_argument('exec', choices=['base', 'py', 'cgo', 'benchmark'])
args = parser.parse_args() args = parser.parse_args()
if args.exec == 'base': if args.exec == 'base':
from .execution.base import BaseExecutionEngine from .pytorch.graph import BaseExecutionEngine
engine = BaseExecutionEngine engine = BaseExecutionEngine
elif args.exec == 'cgo': elif args.exec == 'cgo':
from .execution.cgo_engine import CGOExecutionEngine from .pytorch.cgo import CGOExecutionEngine
engine = CGOExecutionEngine engine = CGOExecutionEngine
elif args.exec == 'py': elif args.exec == 'py':
from .execution.python import PurePythonExecutionEngine from .pytorch.simplified import PurePythonExecutionEngine
engine = PurePythonExecutionEngine engine = PurePythonExecutionEngine
elif args.exec == 'benchmark': elif args.exec == 'benchmark':
from .execution.benchmark import BenchmarkExecutionEngine from .pytorch.benchmark import BenchmarkExecutionEngine
engine = BenchmarkExecutionEngine engine = BenchmarkExecutionEngine
else: else:
raise ValueError(f'Unrecognized benchmark name: {args.exec}') raise ValueError(f'Unrecognized benchmark name: {args.exec}')
engine.trial_execute_graph() engine.trial_execute_graph()
if __name__ == '__main__':
main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .experiment_config import *
from .engine_config import *
\ No newline at end of file
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
from __future__ import annotations from __future__ import annotations
__all__ = ['RetiariiExeConfig', 'RetiariiExperiment', 'preprocess_model', 'debug_mutated_model']
import logging import logging
import warnings import warnings
from threading import Thread from threading import Thread
from typing import Any, List, Union, cast from typing import Any, List, cast
import colorama import colorama
...@@ -16,32 +18,27 @@ import torch.nn as nn ...@@ -16,32 +18,27 @@ import torch.nn as nn
from nni.experiment import Experiment, RunMode from nni.experiment import Experiment, RunMode
from nni.experiment.config.training_services import RemoteConfig from nni.experiment.config.training_services import RemoteConfig
from nni.nas.execution import list_models, set_execution_engine
from nni.nas.execution.common import RetiariiAdvisor, get_mutation_dict
from nni.nas.execution.pytorch.codegen import model_to_pytorch_script
from nni.nas.execution.pytorch.converter import convert_to_graph
from nni.nas.execution.pytorch.converter.graph_gen import GraphConverterWithShape
from nni.nas.evaluator import Evaluator
from nni.nas.mutable import Mutator
from nni.nas.nn.pytorch.mutator import (
extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations, process_oneshot_mutations
)
from nni.nas.utils import is_model_wrapped
from nni.nas.strategy import BaseStrategy
from nni.nas.strategy.utils import dry_run_for_formatted_search_space
from .config import ( from .config import (
RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig, RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig,
PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
) )
from ..codegen.pytorch import model_to_pytorch_script
from ..converter import convert_to_graph
from ..converter.graph_gen import GraphConverterWithShape
from ..execution import list_models, set_execution_engine
from ..execution.utils import get_mutation_dict
from ..graph import Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import (
extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations, process_oneshot_mutations
)
from ..oneshot.interface import BaseOneShotTrainer
from ..serializer import is_model_wrapped
from ..strategy import BaseStrategy
from ..strategy.utils import dry_run_for_formatted_search_space
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
__all__ = ['RetiariiExperiment']
def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False): def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):
# TODO: this logic might need to be refactored into execution engine # TODO: this logic might need to be refactored into execution engine
if oneshot: if oneshot:
...@@ -97,7 +94,7 @@ def debug_mutated_model(base_model, evaluator, applied_mutators): ...@@ -97,7 +94,7 @@ def debug_mutated_model(base_model, evaluator, applied_mutators):
a list of mutators that will be applied on the base model for generating a new model a list of mutators that will be applied on the base model for generating a new model
""" """
base_model_ir, applied_mutators = preprocess_model(base_model, evaluator, applied_mutators) base_model_ir, applied_mutators = preprocess_model(base_model, evaluator, applied_mutators)
from ..strategy.local_debug_strategy import _LocalDebugStrategy from nni.nas.strategy.debug import _LocalDebugStrategy
strategy = _LocalDebugStrategy() strategy = _LocalDebugStrategy()
strategy.run(base_model_ir, applied_mutators) strategy.run(base_model_ir, applied_mutators)
_logger.info('local debug completed!') _logger.info('local debug completed!')
...@@ -174,10 +171,10 @@ class RetiariiExperiment(Experiment): ...@@ -174,10 +171,10 @@ class RetiariiExperiment(Experiment):
""" """
def __init__(self, base_model: nn.Module, def __init__(self, base_model: nn.Module,
evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None), evaluator: Evaluator = cast(Evaluator, None),
applied_mutators: List[Mutator] = cast(List[Mutator], None), applied_mutators: List[Mutator] = cast(List[Mutator], None),
strategy: BaseStrategy = cast(BaseStrategy, None), strategy: BaseStrategy = cast(BaseStrategy, None),
trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)): trainer: Any = None):
super().__init__(None) super().__init__(None)
self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None) self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
...@@ -190,7 +187,7 @@ class RetiariiExperiment(Experiment): ...@@ -190,7 +187,7 @@ class RetiariiExperiment(Experiment):
raise ValueError('Evaluator should not be none.') raise ValueError('Evaluator should not be none.')
self.base_model = base_model self.base_model = base_model
self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator self.evaluator: Evaluator = evaluator
self.applied_mutators = applied_mutators self.applied_mutators = applied_mutators
self.strategy = strategy self.strategy = strategy
...@@ -222,10 +219,10 @@ class RetiariiExperiment(Experiment): ...@@ -222,10 +219,10 @@ class RetiariiExperiment(Experiment):
def _create_execution_engine(self, config: RetiariiExeConfig) -> None: def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
#TODO: we will probably need a execution engine factory to make this clean and elegant #TODO: we will probably need a execution engine factory to make this clean and elegant
if isinstance(config.execution_engine, BaseEngineConfig): if isinstance(config.execution_engine, BaseEngineConfig):
from ..execution.base import BaseExecutionEngine from nni.nas.execution.pytorch.graph import BaseExecutionEngine
engine = BaseExecutionEngine(self.port, self.url_prefix) engine = BaseExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, CgoEngineConfig): elif isinstance(config.execution_engine, CgoEngineConfig):
from ..execution.cgo_engine import CGOExecutionEngine from nni.nas.execution.pytorch.cgo import CGOExecutionEngine
assert not isinstance(config.training_service, list) \ assert not isinstance(config.training_service, list) \
and config.training_service.platform == 'remote', \ and config.training_service.platform == 'remote', \
...@@ -238,10 +235,10 @@ class RetiariiExperiment(Experiment): ...@@ -238,10 +235,10 @@ class RetiariiExperiment(Experiment):
rest_port=self.port, rest_port=self.port,
rest_url_prefix=self.url_prefix) rest_url_prefix=self.url_prefix)
elif isinstance(config.execution_engine, PyEngineConfig): elif isinstance(config.execution_engine, PyEngineConfig):
from ..execution.python import PurePythonExecutionEngine from nni.nas.execution.pytorch.simplified import PurePythonExecutionEngine
engine = PurePythonExecutionEngine(self.port, self.url_prefix) engine = PurePythonExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, BenchmarkEngineConfig): elif isinstance(config.execution_engine, BenchmarkEngineConfig):
from ..execution.benchmark import BenchmarkExecutionEngine from nni.nas.execution.pytorch.benchmark import BenchmarkExecutionEngine
assert config.execution_engine.benchmark is not None, \ assert config.execution_engine.benchmark is not None, \
'"benchmark" must be set when benchmark execution engine is used.' '"benchmark" must be set when benchmark execution engine is used.'
engine = BenchmarkExecutionEngine(config.execution_engine.benchmark) engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
...@@ -265,12 +262,13 @@ class RetiariiExperiment(Experiment): ...@@ -265,12 +262,13 @@ class RetiariiExperiment(Experiment):
Run the experiment. Run the experiment.
This function will block until experiment finish or error. This function will block until experiment finish or error.
""" """
from nni.retiarii.oneshot.interface import BaseOneShotTrainer
if isinstance(self.evaluator, BaseOneShotTrainer): if isinstance(self.evaluator, BaseOneShotTrainer):
# TODO: will throw a deprecation warning soon warnings.warn('You are using the old implementation of one-shot algos based on One-shot trainer. '
# warnings.warn('You are using the old implementation of one-shot algos based on One-shot trainer. ' 'We will try to convert this trainer to our new implementation to run the algorithm. '
# 'We will try to convert this trainer to our new implementation to run the algorithm. ' 'In case you want to stick to the old implementation, '
# 'In case you want to stick to the old implementation, ' 'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
# 'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
self.evaluator.fit() self.evaluator.fit()
return return
...@@ -344,6 +342,7 @@ class RetiariiExperiment(Experiment): ...@@ -344,6 +342,7 @@ class RetiariiExperiment(Experiment):
config = self.config.canonical_copy() config = self.config.canonical_copy()
assert not isinstance(config.execution_engine, PyEngineConfig), \ assert not isinstance(config.execution_engine, PyEngineConfig), \
'You should use `dict` formatter when using Python execution engine.' 'You should use `dict` formatter when using Python execution engine.'
from nni.retiarii.oneshot.interface import BaseOneShotTrainer
if isinstance(self.evaluator, BaseOneShotTrainer): if isinstance(self.evaluator, BaseOneShotTrainer):
assert top_k == 1, 'Only support top_k is 1 for now.' assert top_k == 1, 'Only support top_k is 1 for now.'
return self.evaluator.export() return self.evaluator.export()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment