Unverified commit 2baae4d0 authored by QuanluZhang, committed by GitHub

[Retiarii] support experiment view and resume (#4985)

parent 97d067e6
@@ -12,6 +12,7 @@ import json
 import logging
 from pathlib import Path
 from typing import Any, List, Optional, Union
+from typing_extensions import Literal
 import yaml
@@ -61,6 +62,7 @@ class ExperimentConfig(ConfigBase):
     # In latter case hybrid training services can have different settings.
     experiment_name: Optional[str] = None
+    experiment_type: Literal['hpo'] = 'hpo'
     search_space_file: Optional[utils.PathLike] = None
     search_space: Any = None
     trial_command: Optional[str] = None  # training service field
...
@@ -15,6 +15,7 @@ __all__ = [
     'fields', 'is_instance', 'validate_type', 'is_path_like',
     'guess_config_type', 'guess_list_config_type',
     'training_service_config_factory', 'load_training_service_config',
+    'load_experiment_config', 'get_experiment_cls_using_config',
     'get_ipv4_address'
 ]
@@ -25,7 +26,7 @@ import json
 import os.path
 from pathlib import Path
 import socket
-import typing
+from typing import Tuple, TYPE_CHECKING, get_type_hints
 import typeguard
@@ -33,8 +34,12 @@ import nni.runtime.config
 from .public import is_missing
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
+    from nni.nas.experiment.pytorch import RetiariiExperiment
+    from nni.nas.experiment.config import RetiariiExeConfig
+    from ...experiment import Experiment
     from ..base import ConfigBase
+    from ..experiment_config import ExperimentConfig
     from ..training_service import TrainingServiceConfig
 ## handle relative path ##
@@ -78,7 +83,7 @@ def fields(config: ConfigBase) -> list[dataclasses.Field]:
     # Similar to `dataclasses.fields()`, but use `typing.get_type_hints()` to get `field.type`.
     # This is useful when postponed evaluation is enabled.
     ret = [copy.copy(field) for field in dataclasses.fields(config)]
-    types = typing.get_type_hints(type(config))
+    types = get_type_hints(type(config))
     for field in ret:
         field.type = types[field.name]
     return ret
@@ -198,3 +203,31 @@ def get_ipv4_address() -> str:
     addr = s.getsockname()[0]
     s.close()
     return addr
+def load_experiment_config(config_json: dict) -> ExperimentConfig | RetiariiExeConfig:
+    _, exp_conf_cls = get_experiment_cls_using_config(config_json)
+    return exp_conf_cls(**config_json)
+
+def get_experiment_cls_using_config(config_json: dict) -> Tuple[type[Experiment] | type[RetiariiExperiment],
+                                                                type[ExperimentConfig] | type[RetiariiExeConfig]]:
+    # avoid circular import and unnecessary dependency on pytorch
+    if 'experimentType' in config_json:
+        if config_json['experimentType'] == 'hpo':
+            from ...experiment import Experiment
+            from ..experiment_config import ExperimentConfig
+            return Experiment, ExperimentConfig
+        elif config_json['experimentType'] == 'nas':
+            from nni.nas.experiment.pytorch import RetiariiExperiment
+            from nni.nas.experiment.config import RetiariiExeConfig
+            return RetiariiExperiment, RetiariiExeConfig
+        else:
+            raise ValueError(f'Unknown experiment_type: {config_json["experimentType"]}')
+    else:
+        if 'executionEngine' in config_json:
+            from nni.nas.experiment.pytorch import RetiariiExperiment
+            from nni.nas.experiment.config import RetiariiExeConfig
+            return RetiariiExperiment, RetiariiExeConfig
+        else:
+            from ...experiment import Experiment
+            from ..experiment_config import ExperimentConfig
+            return Experiment, ExperimentConfig
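
For orientation, a rough usage sketch of the two helpers above when a stored config dict is read back from disk; the dict contents are illustrative placeholders, only the 'experimentType' key drives the dispatch:

    # hypothetical stored config for a NAS experiment
    config_json = {'experimentType': 'nas', 'executionEngine': {'name': 'py'}}
    exp_cls, conf_cls = get_experiment_cls_using_config(config_json)  # -> RetiariiExperiment, RetiariiExeConfig
    config = load_experiment_config(config_json)                      # builds conf_cls(**config_json)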
@@ -18,6 +18,7 @@ from typing import Any, TYPE_CHECKING, cast
 from typing_extensions import Literal
 from .config import ExperimentConfig
+from .config.utils import load_experiment_config
 from . import rest
 from ..tools.nnictl.config_utils import Experiments, Config
 from ..tools.nnictl.nnictl_utils import update_experiment
@@ -203,7 +204,7 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int,
 def get_stopped_experiment_config(exp_id, exp_dir=None):
     config_json = get_stopped_experiment_config_json(exp_id, exp_dir)  # type: ignore
-    config = ExperimentConfig(**config_json)  # type: ignore
+    config = load_experiment_config(config_json)  # type: ignore
     if exp_dir and not os.path.samefile(exp_dir, config.experiment_working_directory):
         msg = 'Experiment working directory provided in command line (%s) is different from experiment config (%s)'
         _logger.warning(msg, exp_dir, config.experiment_working_directory)
...
@@ -3,8 +3,9 @@
 import time
 import warnings
-from typing import Iterable
+from typing import Iterable, cast
+from nni.experiment.config.training_services import RemoteConfig
 from nni.nas.execution.common import (
     Model, ModelStatus,
     AbstractExecutionEngine,
@@ -14,11 +15,44 @@ from nni.nas.execution.common import (
 _execution_engine = None
 _default_listener = None
-__all__ = ['get_execution_engine', 'get_and_register_default_listener',
+__all__ = ['init_execution_engine', 'get_execution_engine', 'get_and_register_default_listener',
            'list_models', 'submit_models', 'wait_models', 'query_available_resources',
            'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']
+def init_execution_engine(config, port, url_prefix) -> AbstractExecutionEngine:
+    from ..experiment.config import (
+        BaseEngineConfig, PyEngineConfig,
+        CgoEngineConfig, BenchmarkEngineConfig
+    )
+    if isinstance(config.execution_engine, BaseEngineConfig):
+        from .pytorch.graph import BaseExecutionEngine
+        return BaseExecutionEngine(port, url_prefix)
+    elif isinstance(config.execution_engine, CgoEngineConfig):
+        from .pytorch.cgo.engine import CGOExecutionEngine
+        assert not isinstance(config.training_service, list) \
+            and config.training_service.platform == 'remote', \
+            "CGO execution engine currently only supports remote training service"
+        assert config.execution_engine.batch_waiting_time is not None \
+            and config.execution_engine.max_concurrency_cgo is not None
+        return CGOExecutionEngine(cast(RemoteConfig, config.training_service),
+                                  max_concurrency=config.execution_engine.max_concurrency_cgo,
+                                  batch_waiting_time=config.execution_engine.batch_waiting_time,
+                                  rest_port=port,
+                                  rest_url_prefix=url_prefix)
+    elif isinstance(config.execution_engine, PyEngineConfig):
+        from .pytorch.simplified import PurePythonExecutionEngine
+        return PurePythonExecutionEngine(port, url_prefix)
+    elif isinstance(config.execution_engine, BenchmarkEngineConfig):
+        from .pytorch.benchmark import BenchmarkExecutionEngine
+        assert config.execution_engine.benchmark is not None, \
+            '"benchmark" must be set when benchmark execution engine is used.'
+        return BenchmarkExecutionEngine(config.execution_engine.benchmark)
+    else:
+        raise ValueError(f'Unsupported engine type: {config.execution_engine}')
+
 def set_execution_engine(engine: AbstractExecutionEngine) -> None:
     global _execution_engine
     if _execution_engine is not None:
...
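
The factory above takes over the engine selection that previously lived in RetiariiExperiment._create_execution_engine; a rough call-site sketch, assuming config is an already canonicalized RetiariiExeConfig and the port/prefix values are placeholders:

    # hypothetical call site mirroring what the experiment does after canonicalization
    engine = init_execution_engine(config, 8080, None)
    set_execution_engine(engine)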
@@ -124,14 +124,22 @@ class Model:
     def _load(ir: Any) -> 'Model':
         model = Model(_internal=True)
         for graph_name, graph_data in ir.items():
-            if graph_name != '_evaluator':
+            if graph_name not in ['_evaluator', 'model_id', 'python_class', 'python_init_params']:
                 Graph._load(model, graph_name, graph_data)._register()
+        if 'model_id' in ir:  # backward compatibility
+            model.model_id = ir['model_id']
+            model.python_class = ir['python_class']
+            model.python_init_params = ir['python_init_params']
         if '_evaluator' in ir:
             model.evaluator = Evaluator._load(ir['_evaluator'])
         return model

     def _dump(self) -> Any:
         ret = {name: graph._dump() for name, graph in self.graphs.items()}
+        # NOTE: only dump some necessary member variable, will be refactored
+        ret['model_id'] = self.model_id
+        ret['python_class'] = self.python_class
+        ret['python_init_params'] = self.python_init_params
         if self.evaluator is not None:
             ret['_evaluator'] = self.evaluator._dump()
         return ret
...
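
A quick round-trip sketch of the extra fields, assuming m is an existing Model instance:

    # hypothetical round trip; model_id, python_class and python_init_params now survive _dump/_load
    ir = m._dump()
    restored = Model._load(ir)
    assert restored.model_id == m.model_id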
@@ -233,3 +233,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
         else:
             return value
         return value
+
+    def handle_import_data(self, data):
+        # FIXME: ignore imported data for now, as strategy has not supported resume
+        pass
@@ -5,6 +5,7 @@ import os
 import sys
 from dataclasses import dataclass
 from typing import Any, Dict, Union, Optional
+from typing_extensions import Literal
 from nni.experiment.config import utils, ExperimentConfig
@@ -12,12 +13,20 @@ from .engine_config import ExecutionEngineConfig
 __all__ = ['RetiariiExeConfig']

-def execution_engine_config_factory(engine_name):
-    # FIXME: may move this function to experiment utils in future
+# TODO: may move this function to experiment utils in future
+def init_execution_engine_config(engine_config: Union[str, dict]) -> ExecutionEngineConfig:
+    if isinstance(engine_config, str):
+        engine_name = engine_config
+    else:
+        engine_name = engine_config['name']
     cls = _get_ee_config_class(engine_name)
     if cls is None:
         raise ValueError(f'Invalid execution engine name: {engine_name}')
-    return cls()
+    engine = cls()
+    if isinstance(engine_config, dict):
+        for key, value in engine_config.items():
+            setattr(engine, key, value)
+    return engine

 def _get_ee_config_class(engine_name):
     for cls in ExecutionEngineConfig.__subclasses__():
@@ -28,6 +37,7 @@ def _get_ee_config_class(engine_name):
 @dataclass(init=False)
 class RetiariiExeConfig(ExperimentConfig):
     # FIXME: refactor this class to inherit from a new common base class with HPO config
+    experiment_type: Literal['nas'] = 'nas'
     search_space: Any = ''
     trial_code_directory: utils.PathLike = '.'
     trial_command: str = '_reserved'
@@ -42,9 +52,24 @@ class RetiariiExeConfig(ExperimentConfig):
                  execution_engine: Union[str, ExecutionEngineConfig] = 'py',
                  **kwargs):
         super().__init__(training_service_platform, **kwargs)
-        self.execution_engine = execution_engine
+        if not utils.is_missing(self.execution_engine):
+            # this branch means kwargs is not {} and self.execution_engine has been assigned in super(),
+            # reassign it because super() may instantiate ExecutionEngineConfig by mistake
+            self.execution_engine = init_execution_engine_config(kwargs['executionEngine'])
+            del kwargs['executionEngine']
+        elif isinstance(execution_engine, str):
+            self.execution_engine = init_execution_engine_config(execution_engine)
+        else:
+            self.execution_engine = execution_engine
+        self._is_complete_config = False
+        if self.search_space != '' and self.trial_code_directory != '.' and self.trial_command != '_reserved':
+            # only experiment view and resume have complete config in init, as the config is directly loaded
+            self._is_complete_config = True

     def _canonicalize(self, _parents):
+        if not self._is_complete_config:
             msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
             if self.search_space != '':
                 raise ValueError(msg.format('search_space', self.search_space))
@@ -56,8 +81,10 @@ class RetiariiExeConfig(ExperimentConfig):
             if self.trial_command != '_reserved' and '-m nni.retiarii.trial_entry' not in self.trial_command:
                 raise ValueError(msg.format('trial_command', self.trial_command))
+        # this canonicalize is necessary because users may assign new execution engine str
+        # after execution engine config is instantiated
         if isinstance(self.execution_engine, str):
-            self.execution_engine = execution_engine_config_factory(self.execution_engine)
+            self.execution_engine = init_execution_engine_config(self.execution_engine)
         _trial_command_params = {
             # Default variables
...
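
As a side note, the renamed factory accepts either form; a short sketch (the engine name 'py' is one of the names resolvable by _get_ee_config_class):

    # hypothetical calls: by name, or from a stored dict whose keys are copied onto the config via setattr
    engine_a = init_execution_engine_config('py')
    engine_b = init_execution_engine_config({'name': 'py'})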
@@ -6,20 +6,22 @@ from __future__ import annotations
 __all__ = ['RetiariiExeConfig', 'RetiariiExperiment', 'preprocess_model', 'debug_mutated_model']

 import logging
+import os
+import time
 import warnings
 from threading import Thread
-from typing import Any, List, cast
+from typing import Any, List, cast, Tuple, TYPE_CHECKING

 import colorama
 import torch
 import torch.nn as nn
-from nni.experiment import Experiment, RunMode
-from nni.experiment.config.training_services import RemoteConfig
+from nni.common import dump, load
+from nni.experiment import Experiment, RunMode, launcher
 from nni.nas.execution import list_models, set_execution_engine
-from nni.nas.execution.common import RetiariiAdvisor, get_mutation_dict
+from nni.nas.execution.api import init_execution_engine
+from nni.nas.execution.common import RetiariiAdvisor, get_mutation_dict, Model
 from nni.nas.execution.pytorch.codegen import model_to_pytorch_script
 from nni.nas.execution.pytorch.converter import convert_to_graph
 from nni.nas.execution.pytorch.converter.graph_gen import GraphConverterWithShape
@@ -36,6 +38,9 @@ from .config import (
     PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
 )
+if TYPE_CHECKING:
+    from nni.experiment.config.utils import PathLike

 _logger = logging.getLogger(__name__)
@@ -170,7 +175,7 @@ class RetiariiExperiment(Experiment):
     ...     final_model = Net()
     """

-    def __init__(self, base_model: nn.Module,
+    def __init__(self, base_model: nn.Module = cast(nn.Module, None),
                  evaluator: Evaluator = cast(Evaluator, None),
                  applied_mutators: List[Mutator] = cast(List[Mutator], None),
                  strategy: BaseStrategy = cast(BaseStrategy, None),
@@ -183,8 +188,16 @@ class RetiariiExperiment(Experiment):
                           'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
             evaluator = trainer

-        if evaluator is None:
-            raise ValueError('Evaluator should not be none.')
+        # base_model is None means the experiment is in resume or view mode
+        if base_model is not None:
+            if evaluator is None:
+                raise ValueError('Evaluator should not be none.')
+            # check for sanity
+            if not is_model_wrapped(base_model):
+                warnings.warn(colorama.Style.BRIGHT + colorama.Fore.RED +
+                              '`@model_wrapper` is missing for the base model. The experiment might still be able to run, '
+                              'but it may cause inconsistent behavior compared to the time when you add it.' + colorama.Style.RESET_ALL,
+                              RuntimeWarning)

         self.base_model = base_model
         self.evaluator: Evaluator = evaluator
@@ -194,58 +207,39 @@ class RetiariiExperiment(Experiment):
         self._dispatcher = None
         self._dispatcher_thread = None

-        # check for sanity
-        if not is_model_wrapped(base_model):
-            warnings.warn(colorama.Style.BRIGHT + colorama.Fore.RED +
-                          '`@model_wrapper` is missing for the base model. The experiment might still be able to run, '
-                          'but it may cause inconsistent behavior compared to the time when you add it.' + colorama.Style.RESET_ALL,
-                          RuntimeWarning)
-
-    def _run_strategy(self, config: RetiariiExeConfig):
-        base_model_ir, self.applied_mutators = preprocess_model(
-            self.base_model, self.evaluator, self.applied_mutators,
-            full_ir=not isinstance(config.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
-            dummy_input=config.execution_engine.dummy_input
-            if isinstance(config.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
-        )
-
+    def _run_strategy(self, base_model_ir: Model, applied_mutators: List[Mutator]) -> None:
         _logger.info('Start strategy...')
-        search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
+        search_space = dry_run_for_formatted_search_space(base_model_ir, applied_mutators)
         self.update_search_space(search_space)
-        self.strategy.run(base_model_ir, self.applied_mutators)
+        self.strategy.run(base_model_ir, applied_mutators)
         _logger.info('Strategy exit')
         # TODO: find out a proper way to show no more trial message on WebUI

     def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
-        # TODO: we will probably need a execution engine factory to make this clean and elegant
-        if isinstance(config.execution_engine, BaseEngineConfig):
-            from nni.nas.execution.pytorch.graph import BaseExecutionEngine
-            engine = BaseExecutionEngine(self.port, self.url_prefix)
-        elif isinstance(config.execution_engine, CgoEngineConfig):
-            from nni.nas.execution.pytorch.cgo import CGOExecutionEngine
-            assert not isinstance(config.training_service, list) \
-                and config.training_service.platform == 'remote', \
-                "CGO execution engine currently only supports remote training service"
-            assert config.execution_engine.batch_waiting_time is not None \
-                and config.execution_engine.max_concurrency_cgo is not None
-            engine = CGOExecutionEngine(cast(RemoteConfig, config.training_service),
-                                        max_concurrency=config.execution_engine.max_concurrency_cgo,
-                                        batch_waiting_time=config.execution_engine.batch_waiting_time,
-                                        rest_port=self.port,
-                                        rest_url_prefix=self.url_prefix)
-        elif isinstance(config.execution_engine, PyEngineConfig):
-            from nni.nas.execution.pytorch.simplified import PurePythonExecutionEngine
-            engine = PurePythonExecutionEngine(self.port, self.url_prefix)
-        elif isinstance(config.execution_engine, BenchmarkEngineConfig):
-            from nni.nas.execution.pytorch.benchmark import BenchmarkExecutionEngine
-            assert config.execution_engine.benchmark is not None, \
-                '"benchmark" must be set when benchmark execution engine is used.'
-            engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
-        else:
-            raise ValueError(f'Unsupported engine type: {config.execution_engine}')
+        engine = init_execution_engine(config, self.port, self.url_prefix)
         set_execution_engine(engine)
+    def _save_experiment_checkpoint(self, base_model_ir: Model, applied_mutators: List[Mutator],
+                                    strategy: BaseStrategy, exp_work_dir: PathLike) -> None:
+        ckp_path = os.path.join(exp_work_dir, self.id, 'checkpoint')
+        with open(os.path.join(ckp_path, 'nas_model'), 'w') as fp:
+            dump(base_model_ir._dump(), fp, pickle_size_limit=int(os.getenv('PICKLE_SIZE_LIMIT', 64 * 1024)))
+        with open(os.path.join(ckp_path, 'applied_mutators'), 'w') as fp:
+            dump(applied_mutators, fp)
+        with open(os.path.join(ckp_path, 'strategy'), 'w') as fp:
+            dump(strategy, fp)
+
+    def _load_experiment_checkpoint(self, exp_work_dir: PathLike) -> Tuple[Model, List[Mutator], BaseStrategy]:
+        ckp_path = os.path.join(exp_work_dir, self.id, 'checkpoint')
+        with open(os.path.join(ckp_path, 'nas_model'), 'r') as fp:
+            base_model_ir = load(fp=fp)
+            base_model_ir = Model._load(base_model_ir)
+        with open(os.path.join(ckp_path, 'applied_mutators'), 'r') as fp:
+            applied_mutators = load(fp=fp)
+        with open(os.path.join(ckp_path, 'strategy'), 'r') as fp:
+            strategy = load(fp=fp)
+        return base_model_ir, applied_mutators, strategy
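
For reference, the two methods above imply the following checkpoint layout under the experiment working directory; the tree is a sketch derived from the paths in the code:

    # <experiment_working_directory>/<experiment_id>/checkpoint/
    #     nas_model          # dumped Model IR (base_model_ir._dump())
    #     applied_mutators   # dumped list of Mutator
    #     strategy           # dumped BaseStrategy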
     def start(self, *args, **kwargs) -> None:
         """
         By design, the only different between `start` and `run` is that `start` is asynchronous,
@@ -262,7 +256,6 @@ class RetiariiExperiment(Experiment):
         Run the experiment.
         This function will block until experiment finish or error.
         """
-
         from nni.retiarii.oneshot.interface import BaseOneShotTrainer
         if isinstance(self.evaluator, BaseOneShotTrainer):
             warnings.warn('You are using the old implementation of one-shot algos based on One-shot trainer. '
@@ -287,15 +280,30 @@ class RetiariiExperiment(Experiment):
             self.strategy.run(base_model_ir, self.applied_mutators)
         else:
             ws_url = f'ws://localhost:{port}/tuner'
-            canonicalized_config = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
-            canonicalized_config = cast(RetiariiExeConfig, canonicalized_config)
+            canoni_conf = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
+            canoni_conf = cast(RetiariiExeConfig, canoni_conf)
             self._dispatcher = RetiariiAdvisor(ws_url)
             self._dispatcher_thread = Thread(target=self._dispatcher.run, daemon=True)
             self._dispatcher_thread.start()
             # FIXME: engine cannot be created twice
-            self._create_execution_engine(canonicalized_config)
+            self._create_execution_engine(canoni_conf)
             try:
-                self._run_strategy(canonicalized_config)
+                if self._action == 'create':
+                    base_model_ir, self.applied_mutators = preprocess_model(
+                        self.base_model, self.evaluator, self.applied_mutators,
+                        full_ir=not isinstance(canoni_conf.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
+                        dummy_input=canoni_conf.execution_engine.dummy_input
+                        if isinstance(canoni_conf.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
+                    )
+                    self._save_experiment_checkpoint(base_model_ir, self.applied_mutators, self.strategy,
+                                                     canoni_conf.experiment_working_directory)
+                elif self._action == 'resume':
+                    base_model_ir, self.applied_mutators, self.strategy = self._load_experiment_checkpoint(
+                        canoni_conf.experiment_working_directory)
+                else:
+                    raise RuntimeError(f'The experiment mode "{self._action}" is not supposed to invoke run() method.')
+                self._run_strategy(base_model_ir, self.applied_mutators)
                 # FIXME: move this logic to strategy with a new API provided by execution engine
                 self._wait_completion()
             except KeyboardInterrupt:
@@ -359,3 +367,67 @@ class RetiariiExperiment(Experiment):
             return [model_to_pytorch_script(model) for model in all_models[:top_k]]
         elif formatter == 'dict':
             return [get_mutation_dict(model) for model in all_models[:top_k]]
+
+    @staticmethod
+    def view(experiment_id: str, port: int = 8080, non_blocking: bool = False) -> RetiariiExperiment | None:
+        """
+        View a stopped experiment.
+
+        Parameters
+        ----------
+        experiment_id
+            The stopped experiment id.
+        port
+            The port of web UI.
+        non_blocking
+            If false, run in the foreground. If true, run in the background.
+        """
+        experiment = RetiariiExperiment._view(experiment_id)
+        # view is nothing specific about RetiariiExperiment, directly using the method in base experiment class
+        super(RetiariiExperiment, experiment).start(port=port, debug=False, run_mode=RunMode.Detach)
+        if non_blocking:
+            return experiment
+        else:
+            try:
+                while True:
+                    time.sleep(10)
+            except KeyboardInterrupt:
+                _logger.warning('KeyboardInterrupt detected')
+            finally:
+                experiment.stop()
+
+    @staticmethod
+    def resume(experiment_id: str, port: int = 8080, debug: bool = False) -> RetiariiExperiment:
+        """
+        Resume a stopped experiment.
+
+        Parameters
+        ----------
+        experiment_id
+            The stopped experiment id.
+        port
+            The port of web UI.
+        debug
+            Whether to start in debug mode.
+        """
+        experiment = RetiariiExperiment._resume(experiment_id)
+        experiment.run(experiment.config, port=port, debug=debug)
+        # always return experiment for user's follow-up operations on the experiment
+        # wait_completion is not necessary as nas experiment is always in foreground
+        return experiment
+
+    @staticmethod
+    def _resume(exp_id, exp_dir=None):
+        exp = RetiariiExperiment(cast(nn.Module, None))
+        exp.id = exp_id
+        exp._action = 'resume'
+        exp.config = cast(RetiariiExeConfig, launcher.get_stopped_experiment_config(exp_id, exp_dir))
+        return exp
+
+    @staticmethod
+    def _view(exp_id, exp_dir=None):
+        exp = RetiariiExperiment(cast(nn.Module, None))
+        exp.id = exp_id
+        exp._action = 'view'
+        exp.config = cast(RetiariiExeConfig, launcher.get_stopped_experiment_config(exp_id, exp_dir))
+        return exp
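
A quick usage sketch of the new entry points; the experiment id and port are placeholders:

    # hypothetical follow-up on a stopped NAS experiment with id 'abc123'
    exp = RetiariiExperiment.view('abc123', port=8081, non_blocking=True)
    exp.stop()
    RetiariiExperiment.resume('abc123', port=8081)  # blocks until the resumed run finishes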
@@ -104,9 +104,14 @@ def resume_experiment(args):
         legacy_launcher.resume_experiment(args)
         exit()

-    exp = Experiment._resume(exp_id, exp_dir)
-    run_mode = RunMode.Foreground if foreground else RunMode.Detach
-    exp.start(port, debug, run_mode)
+    exp_cls, _ = utils.get_experiment_cls_using_config(config_json)
+    if exp_cls is Experiment:
+        exp = exp_cls._resume(exp_id, exp_dir)
+        run_mode = RunMode.Foreground if foreground else RunMode.Detach
+        exp.start(port, debug, run_mode)
+    else:
+        # exp_cls is RetiariiExperiment
+        exp_cls.resume(exp_id, port, debug)

 def view_experiment(args):
     exp_id = args.id
@@ -118,5 +123,10 @@ def view_experiment(args):
         legacy_launcher.view_experiment(args)
         exit()

-    exp = Experiment._view(exp_id, exp_dir)
-    exp.start(port, run_mode=RunMode.Detach)
+    exp_cls, _ = utils.get_experiment_cls_using_config(config_json)
+    if exp_cls is Experiment:
+        exp = exp_cls._view(exp_id, exp_dir)
+        exp.start(port, run_mode=RunMode.Detach)
+    else:
+        # exp_cls is RetiariiExperiment
+        exp_cls.view(exp_id, port, non_blocking=True)
 import multiprocessing
 import os
-import sys
+import subprocess
 import time

 import pytest
@@ -76,3 +76,44 @@ def test_exp_exit_without_stop(pytestconfig):
             return
     process.kill()
     raise RuntimeError(f'Experiment fails to stop in {timeout} seconds.')
+
+def test_multitrial_experiment_resume_view(pytestconfig):
+    # start a normal nas experiment
+    base_model, evaluator = _mnist_net('simple', {'max_epochs': 1})
+    search_strategy = strategy.Random()
+    exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
+    exp_id = exp.id
+    exp_config = RetiariiExeConfig('local')
+    exp_config.trial_concurrency = 1
+    exp_config.max_trial_number = 1
+    exp_config._trial_command_params = nas_experiment_trial_params(pytestconfig.rootpath)
+    exp.run(exp_config)
+    ensure_success(exp)
+    assert isinstance(exp.export_top_models()[0], dict)
+    exp.stop()
+
+    # resume the above nas experiment. only the resume logic on the python side is tested here,
+    # as no more trial is executed after resume; the above experiment is already finished
+    print('python api resume...')
+    exp = RetiariiExperiment.resume(exp_id)
+    ensure_success(exp)
+    # TODO: currently `export_top_models` does not work as strategy's states are not resumed
+    # assert isinstance(exp.export_top_models()[0], dict)
+    exp.stop()
+
+    # view the above experiment in non blocking mode then stop it
+    print('python api view...')
+    exp = RetiariiExperiment.view(exp_id, non_blocking=True)
+    exp.stop()
+
+    # the following is nnictl resume and view
+    print('nnictl resume...')
+    new_env = os.environ.copy()
+    new_env['PYTHONPATH'] = str(pytestconfig.rootpath)
+    proc = subprocess.run(f'nnictl resume {exp_id}', shell=True, env=new_env)
+    assert proc.returncode == 0, 'resume nas experiment failed with code %d' % proc.returncode
+    print('nnictl view...')
+    proc = subprocess.run(f'nnictl view {exp_id}', shell=True)
+    assert proc.returncode == 0, 'view nas experiment failed with code %d' % proc.returncode
+    proc = subprocess.run(f'nnictl stop {exp_id}', shell=True)
+    assert proc.returncode == 0, 'stop viewed nas experiment failed with code %d' % proc.returncode
\ No newline at end of file
@@ -28,6 +28,7 @@ minimal_class.trial_concurrency = 2
 minimal_class.tuner.name = 'random'

 minimal_canon = {
+    'experimentType': 'hpo',
     'searchSpace': {'a': 1},
     'trialCommand': 'python main.py',
     'trialCodeDirectory': os.path.realpath('.'),
@@ -54,6 +55,7 @@ minimal_canon_2['tuner']['classArgs'] = {}
 detailed_canon = {
     'experimentName': 'test case',
+    'experimentType': 'hpo',
     'searchSpaceFile': expand_path('assets/search_space.json'),
     'searchSpace': {'a': 1},
     'trialCommand': 'python main.py',
...
@@ -43,6 +43,7 @@ minimal_class = ExperimentConfig(
 )

 minimal_canon = {
+    'experimentType': 'hpo',
     'searchSpace': {'a': 1},
     'trialCommand': 'python main.py',
     'trialCodeDirectory': os.path.realpath('.'),
@@ -106,6 +107,7 @@ detailed_json = {
 }

 detailed_canon = {
+    'experimentType': 'hpo',
     'searchSpace': {'a': 1},
     'trialCommand': 'python main.py',
     'trialCodeDirectory': os.path.realpath('.'),
...
@@ -37,6 +37,10 @@ def _test_file(json_path):
     # skip comparison of _evaluator
     orig_ir.pop('_evaluator')
     dump_ir.pop('_evaluator')
+    # skip three experiment fields
+    dump_ir.pop('model_id')
+    dump_ir.pop('python_class')
+    dump_ir.pop('python_init_params')

     assert orig_ir == dump_ir
...