"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "88e73ed048b4af2ae72b2feb876b44cbc5e9dae9"
Unverified Commit 2baae4d0 authored by QuanluZhang, committed by GitHub

[Retiarii] support experiment view and resume (#4985)

parent 97d067e6
......@@ -12,6 +12,7 @@ import json
import logging
from pathlib import Path
from typing import Any, List, Optional, Union
from typing_extensions import Literal
import yaml
......@@ -61,6 +62,7 @@ class ExperimentConfig(ConfigBase):
# In the latter case, hybrid training services can have different settings.
experiment_name: Optional[str] = None
experiment_type: Literal['hpo'] = 'hpo'
search_space_file: Optional[utils.PathLike] = None
search_space: Any = None
trial_command: Optional[str] = None # training service field
......
......@@ -15,6 +15,7 @@ __all__ = [
'fields', 'is_instance', 'validate_type', 'is_path_like',
'guess_config_type', 'guess_list_config_type',
'training_service_config_factory', 'load_training_service_config',
'load_experiment_config', 'get_experiment_cls_using_config',
'get_ipv4_address'
]
......@@ -25,7 +26,7 @@ import json
import os.path
from pathlib import Path
import socket
import typing
from typing import Tuple, TYPE_CHECKING, get_type_hints
import typeguard
......@@ -33,8 +34,12 @@ import nni.runtime.config
from .public import is_missing
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from nni.nas.experiment.pytorch import RetiariiExperiment
from nni.nas.experiment.config import RetiariiExeConfig
from ...experiment import Experiment
from ..base import ConfigBase
from ..experiment_config import ExperimentConfig
from ..training_service import TrainingServiceConfig
## handle relative path ##
......@@ -78,7 +83,7 @@ def fields(config: ConfigBase) -> list[dataclasses.Field]:
# Similar to `dataclasses.fields()`, but use `typing.get_type_hints()` to get `field.type`.
# This is useful when postponed evaluation is enabled.
ret = [copy.copy(field) for field in dataclasses.fields(config)]
types = typing.get_type_hints(type(config))
types = get_type_hints(type(config))
for field in ret:
field.type = types[field.name]
return ret
......@@ -198,3 +203,31 @@ def get_ipv4_address() -> str:
addr = s.getsockname()[0]
s.close()
return addr
def load_experiment_config(config_json: dict) -> ExperimentConfig | RetiariiExeConfig:
_, exp_conf_cls = get_experiment_cls_using_config(config_json)
return exp_conf_cls(**config_json)
def get_experiment_cls_using_config(config_json: dict) -> Tuple[type[Experiment] | type[RetiariiExperiment],
type[ExperimentConfig] | type[RetiariiExeConfig]]:
# avoid circular imports and an unnecessary dependency on PyTorch
if 'experimentType' in config_json:
if config_json['experimentType'] == 'hpo':
from ...experiment import Experiment
from ..experiment_config import ExperimentConfig
return Experiment, ExperimentConfig
elif config_json['experimentType'] == 'nas':
from nni.nas.experiment.pytorch import RetiariiExperiment
from nni.nas.experiment.config import RetiariiExeConfig
return RetiariiExperiment, RetiariiExeConfig
else:
raise ValueError(f'Unknown experiment_type: {config_json["experimentType"]}')
else:
if 'executionEngine' in config_json:
from nni.nas.experiment.pytorch import RetiariiExperiment
from nni.nas.experiment.config import RetiariiExeConfig
return RetiariiExperiment, RetiariiExeConfig
else:
from ...experiment import Experiment
from ..experiment_config import ExperimentConfig
return Experiment, ExperimentConfig
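A rough sketch of the dispatch above (a hypothetical helper that returns class names as strings, purely for illustration; the real function imports and returns the classes themselves):

def _pick_experiment_classes(config_json: dict):
    # 'experimentType' wins; 'executionEngine' is the legacy fallback for older NAS configs
    if 'experimentType' in config_json:
        if config_json['experimentType'] == 'hpo':
            return 'Experiment', 'ExperimentConfig'
        if config_json['experimentType'] == 'nas':
            return 'RetiariiExperiment', 'RetiariiExeConfig'
        raise ValueError(f'Unknown experiment_type: {config_json["experimentType"]}')
    if 'executionEngine' in config_json:
        return 'RetiariiExperiment', 'RetiariiExeConfig'
    return 'Experiment', 'ExperimentConfig'

assert _pick_experiment_classes({'experimentType': 'nas'}) == ('RetiariiExperiment', 'RetiariiExeConfig')
assert _pick_experiment_classes({'executionEngine': 'py'}) == ('RetiariiExperiment', 'RetiariiExeConfig')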
......@@ -18,6 +18,7 @@ from typing import Any, TYPE_CHECKING, cast
from typing_extensions import Literal
from .config import ExperimentConfig
from .config.utils import load_experiment_config
from . import rest
from ..tools.nnictl.config_utils import Experiments, Config
from ..tools.nnictl.nnictl_utils import update_experiment
......@@ -203,7 +204,7 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int,
def get_stopped_experiment_config(exp_id, exp_dir=None):
config_json = get_stopped_experiment_config_json(exp_id, exp_dir) # type: ignore
config = ExperimentConfig(**config_json) # type: ignore
config = load_experiment_config(config_json) # type: ignore
if exp_dir and not os.path.samefile(exp_dir, config.experiment_working_directory):
msg = 'Experiment working directory provided in command line (%s) is different from experiment config (%s)'
_logger.warning(msg, exp_dir, config.experiment_working_directory)
......
......@@ -3,8 +3,9 @@
import time
import warnings
from typing import Iterable
from typing import Iterable, cast
from nni.experiment.config.training_services import RemoteConfig
from nni.nas.execution.common import (
Model, ModelStatus,
AbstractExecutionEngine,
......@@ -14,11 +15,44 @@ from nni.nas.execution.common import (
_execution_engine = None
_default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener',
__all__ = ['init_execution_engine', 'get_execution_engine', 'get_and_register_default_listener',
'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']
def init_execution_engine(config, port, url_prefix) -> AbstractExecutionEngine:
from ..experiment.config import (
BaseEngineConfig, PyEngineConfig,
CgoEngineConfig, BenchmarkEngineConfig
)
if isinstance(config.execution_engine, BaseEngineConfig):
from .pytorch.graph import BaseExecutionEngine
return BaseExecutionEngine(port, url_prefix)
elif isinstance(config.execution_engine, CgoEngineConfig):
from .pytorch.cgo.engine import CGOExecutionEngine
assert not isinstance(config.training_service, list) \
and config.training_service.platform == 'remote', \
"CGO execution engine currently only supports remote training service"
assert config.execution_engine.batch_waiting_time is not None \
and config.execution_engine.max_concurrency_cgo is not None
return CGOExecutionEngine(cast(RemoteConfig, config.training_service),
max_concurrency=config.execution_engine.max_concurrency_cgo,
batch_waiting_time=config.execution_engine.batch_waiting_time,
rest_port=port,
rest_url_prefix=url_prefix)
elif isinstance(config.execution_engine, PyEngineConfig):
from .pytorch.simplified import PurePythonExecutionEngine
return PurePythonExecutionEngine(port, url_prefix)
elif isinstance(config.execution_engine, BenchmarkEngineConfig):
from .pytorch.benchmark import BenchmarkExecutionEngine
assert config.execution_engine.benchmark is not None, \
'"benchmark" must be set when benchmark execution engine is used.'
return BenchmarkExecutionEngine(config.execution_engine.benchmark)
else:
raise ValueError(f'Unsupported engine type: {config.execution_engine}')
def set_execution_engine(engine: AbstractExecutionEngine) -> None:
global _execution_engine
if _execution_engine is not None:
......
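The dispatch in init_execution_engine above, summarized as a mapping (class names as they appear in the diff; this dict is illustrative only and not part of the API):

ENGINE_FOR_CONFIG = {
    'BaseEngineConfig': 'BaseExecutionEngine',
    'CgoEngineConfig': 'CGOExecutionEngine',              # remote training service only
    'PyEngineConfig': 'PurePythonExecutionEngine',
    'BenchmarkEngineConfig': 'BenchmarkExecutionEngine',  # requires execution_engine.benchmark to be set
}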
......@@ -124,14 +124,22 @@ class Model:
def _load(ir: Any) -> 'Model':
model = Model(_internal=True)
for graph_name, graph_data in ir.items():
if graph_name != '_evaluator':
if graph_name not in ['_evaluator', 'model_id', 'python_class', 'python_init_params']:
Graph._load(model, graph_name, graph_data)._register()
if 'model_id' in ir: # backward compatibility
model.model_id = ir['model_id']
model.python_class = ir['python_class']
model.python_init_params = ir['python_init_params']
if '_evaluator' in ir:
model.evaluator = Evaluator._load(ir['_evaluator'])
return model
def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()}
# NOTE: only dump the necessary member variables; will be refactored
ret['model_id'] = self.model_id
ret['python_class'] = self.python_class
ret['python_init_params'] = self.python_init_params
if self.evaluator is not None:
ret['_evaluator'] = self.evaluator._dump()
return ret
......
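A hedged sketch of the intended round-trip for the fields added above (assumes `model` is an existing Model instance; graph contents are elided):

ir = model._dump()          # now also carries model_id, python_class and python_init_params
restored = Model._load(ir)  # rebuilds the graphs and, for newer IRs, restores those three fields
assert restored.model_id == model.model_id
assert restored.python_init_params == model.python_init_params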
......@@ -233,3 +233,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
else:
return value
return value
def handle_import_data(self, data):
# FIXME: ignore imported data for now, as strategies do not support resume yet
pass
......@@ -5,6 +5,7 @@ import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, Union, Optional
from typing_extensions import Literal
from nni.experiment.config import utils, ExperimentConfig
......@@ -12,12 +13,20 @@ from .engine_config import ExecutionEngineConfig
__all__ = ['RetiariiExeConfig']
def execution_engine_config_factory(engine_name):
# FIXME: may move this function to experiment utils in future
# TODO: may move this function to experiment utils in future
def init_execution_engine_config(engine_config: Union[str, dict]) -> ExecutionEngineConfig:
if isinstance(engine_config, str):
engine_name = engine_config
else:
engine_name = engine_config['name']
cls = _get_ee_config_class(engine_name)
if cls is None:
raise ValueError(f'Invalid execution engine name: {engine_name}')
return cls()
engine = cls()
if isinstance(engine_config, dict):
for key, value in engine_config.items():
setattr(engine, key, value)
return engine
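A minimal usage sketch of the factory above ('py' resolves to PyEngineConfig via _get_ee_config_class; the dict form is what the resume path feeds in from a loaded config):

engine_cfg = init_execution_engine_config('py')            # plain name -> PyEngineConfig()
engine_cfg = init_execution_engine_config({'name': 'py'})  # dict form: keys are copied onto the config via setattr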
def _get_ee_config_class(engine_name):
for cls in ExecutionEngineConfig.__subclasses__():
......@@ -28,6 +37,7 @@ def _get_ee_config_class(engine_name):
@dataclass(init=False)
class RetiariiExeConfig(ExperimentConfig):
# FIXME: refactor this class to inherit from a new common base class with HPO config
experiment_type: Literal['nas'] = 'nas'
search_space: Any = ''
trial_code_directory: utils.PathLike = '.'
trial_command: str = '_reserved'
......@@ -42,34 +52,51 @@ class RetiariiExeConfig(ExperimentConfig):
execution_engine: Union[str, ExecutionEngineConfig] = 'py',
**kwargs):
super().__init__(training_service_platform, **kwargs)
self.execution_engine = execution_engine
if not utils.is_missing(self.execution_engine):
# this branch means kwargs is not {} and self.execution_engine was already assigned in super();
# reassign it because super() may have instantiated ExecutionEngineConfig incorrectly
self.execution_engine = init_execution_engine_config(kwargs['executionEngine'])
del kwargs['executionEngine']
elif isinstance(execution_engine, str):
self.execution_engine = init_execution_engine_config(execution_engine)
else:
self.execution_engine = execution_engine
self._is_complete_config = False
if self.search_space != '' and self.trial_code_directory != '.' and self.trial_command != '_reserved':
# only experiment view and resume have a complete config at init time, as the config is loaded directly
self._is_complete_config = True
def _canonicalize(self, _parents):
msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
if self.search_space != '':
raise ValueError(msg.format('search_space', self.search_space))
# TODO: maybe we should also allow users to specify trial_code_directory
if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
trial_command_tmpl = '{envs} {python} -m nni.retiarii.trial_entry {execution_engine}'
if self.trial_command != '_reserved' and '-m nni.retiarii.trial_entry' not in self.trial_command:
raise ValueError(msg.format('trial_command', self.trial_command))
if isinstance(self.execution_engine, str):
self.execution_engine = execution_engine_config_factory(self.execution_engine)
_trial_command_params = {
# Default variables
'envs': '',
# TODO: maybe use sys.executable rendered in trial side (e.g., trial_runner)
'python': sys.executable,
'execution_engine': self.execution_engine.name,
# This should override the parameters above.
**(self._trial_command_params or {})
}
self.trial_command = trial_command_tmpl.format(**_trial_command_params).strip()
if not self._is_complete_config:
msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
if self.search_space != '':
raise ValueError(msg.format('search_space', self.search_space))
# TODO: maybe we should also allow users to specify trial_code_directory
if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
trial_command_tmpl = '{envs} {python} -m nni.retiarii.trial_entry {execution_engine}'
if self.trial_command != '_reserved' and '-m nni.retiarii.trial_entry' not in self.trial_command:
raise ValueError(msg.format('trial_command', self.trial_command))
# this canonicalization is necessary because users may assign a new execution engine str
# after the execution engine config has been instantiated
if isinstance(self.execution_engine, str):
self.execution_engine = init_execution_engine_config(self.execution_engine)
_trial_command_params = {
# Default variables
'envs': '',
# TODO: maybe use sys.executable rendered in trial side (e.g., trial_runner)
'python': sys.executable,
'execution_engine': self.execution_engine.name,
# This should override the parameters above.
**(self._trial_command_params or {})
}
self.trial_command = trial_command_tmpl.format(**_trial_command_params).strip()
super()._canonicalize([self])
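For reference, the template above renders a trial command like the following (the interpreter path is illustrative; 'py' is the name of the default engine config):

tmpl = '{envs} {python} -m nni.retiarii.trial_entry {execution_engine}'
print(tmpl.format(envs='', python='/usr/bin/python3', execution_engine='py').strip())
# -> /usr/bin/python3 -m nni.retiarii.trial_entry py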
......@@ -6,20 +6,22 @@ from __future__ import annotations
__all__ = ['RetiariiExeConfig', 'RetiariiExperiment', 'preprocess_model', 'debug_mutated_model']
import logging
import os
import time
import warnings
from threading import Thread
from typing import Any, List, cast
from typing import Any, List, cast, Tuple, TYPE_CHECKING
import colorama
import torch
import torch.nn as nn
from nni.experiment import Experiment, RunMode
from nni.experiment.config.training_services import RemoteConfig
from nni.common import dump, load
from nni.experiment import Experiment, RunMode, launcher
from nni.nas.execution import list_models, set_execution_engine
from nni.nas.execution.common import RetiariiAdvisor, get_mutation_dict
from nni.nas.execution.api import init_execution_engine
from nni.nas.execution.common import RetiariiAdvisor, get_mutation_dict, Model
from nni.nas.execution.pytorch.codegen import model_to_pytorch_script
from nni.nas.execution.pytorch.converter import convert_to_graph
from nni.nas.execution.pytorch.converter.graph_gen import GraphConverterWithShape
......@@ -36,6 +38,9 @@ from .config import (
PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
)
if TYPE_CHECKING:
from nni.experiment.config.utils import PathLike
_logger = logging.getLogger(__name__)
......@@ -170,7 +175,7 @@ class RetiariiExperiment(Experiment):
... final_model = Net()
"""
def __init__(self, base_model: nn.Module,
def __init__(self, base_model: nn.Module = cast(nn.Module, None),
evaluator: Evaluator = cast(Evaluator, None),
applied_mutators: List[Mutator] = cast(List[Mutator], None),
strategy: BaseStrategy = cast(BaseStrategy, None),
......@@ -183,8 +188,16 @@ class RetiariiExperiment(Experiment):
'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
evaluator = trainer
if evaluator is None:
raise ValueError('Evaluator should not be none.')
# base_model being None means the experiment is in resume or view mode
if base_model is not None:
if evaluator is None:
raise ValueError('Evaluator should not be none.')
# check for sanity
if not is_model_wrapped(base_model):
warnings.warn(colorama.Style.BRIGHT + colorama.Fore.RED +
'`@model_wrapper` is missing for the base model. The experiment might still be able to run, '
'but it may cause inconsistent behavior compared to the time when you add it.' + colorama.Style.RESET_ALL,
RuntimeWarning)
self.base_model = base_model
self.evaluator: Evaluator = evaluator
......@@ -194,58 +207,39 @@ class RetiariiExperiment(Experiment):
self._dispatcher = None
self._dispatcher_thread = None
# check for sanity
if not is_model_wrapped(base_model):
warnings.warn(colorama.Style.BRIGHT + colorama.Fore.RED +
'`@model_wrapper` is missing for the base model. The experiment might still be able to run, '
'but it may cause inconsistent behavior compared to the time when you add it.' + colorama.Style.RESET_ALL,
RuntimeWarning)
def _run_strategy(self, config: RetiariiExeConfig):
base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.evaluator, self.applied_mutators,
full_ir=not isinstance(config.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
dummy_input=config.execution_engine.dummy_input
if isinstance(config.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
)
def _run_strategy(self, base_model_ir: Model, applied_mutators: List[Mutator]) -> None:
_logger.info('Start strategy...')
search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
search_space = dry_run_for_formatted_search_space(base_model_ir, applied_mutators)
self.update_search_space(search_space)
self.strategy.run(base_model_ir, self.applied_mutators)
self.strategy.run(base_model_ir, applied_mutators)
_logger.info('Strategy exit')
# TODO: find out a proper way to show no more trial message on WebUI
def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
#TODO: we will probably need a execution engine factory to make this clean and elegant
if isinstance(config.execution_engine, BaseEngineConfig):
from nni.nas.execution.pytorch.graph import BaseExecutionEngine
engine = BaseExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, CgoEngineConfig):
from nni.nas.execution.pytorch.cgo import CGOExecutionEngine
assert not isinstance(config.training_service, list) \
and config.training_service.platform == 'remote', \
"CGO execution engine currently only supports remote training service"
assert config.execution_engine.batch_waiting_time is not None \
and config.execution_engine.max_concurrency_cgo is not None
engine = CGOExecutionEngine(cast(RemoteConfig, config.training_service),
max_concurrency=config.execution_engine.max_concurrency_cgo,
batch_waiting_time=config.execution_engine.batch_waiting_time,
rest_port=self.port,
rest_url_prefix=self.url_prefix)
elif isinstance(config.execution_engine, PyEngineConfig):
from nni.nas.execution.pytorch.simplified import PurePythonExecutionEngine
engine = PurePythonExecutionEngine(self.port, self.url_prefix)
elif isinstance(config.execution_engine, BenchmarkEngineConfig):
from nni.nas.execution.pytorch.benchmark import BenchmarkExecutionEngine
assert config.execution_engine.benchmark is not None, \
'"benchmark" must be set when benchmark execution engine is used.'
engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
else:
raise ValueError(f'Unsupported engine type: {config.execution_engine}')
engine = init_execution_engine(config, self.port, self.url_prefix)
set_execution_engine(engine)
def _save_experiment_checkpoint(self, base_model_ir: Model, applied_mutators: List[Mutator],
strategy: BaseStrategy, exp_work_dir: PathLike) -> None:
ckp_path = os.path.join(exp_work_dir, self.id, 'checkpoint')
with open(os.path.join(ckp_path, 'nas_model'), 'w') as fp:
dump(base_model_ir._dump(), fp, pickle_size_limit=int(os.getenv('PICKLE_SIZE_LIMIT', 64 * 1024)))
with open(os.path.join(ckp_path, 'applied_mutators'), 'w') as fp:
dump(applied_mutators, fp)
with open(os.path.join(ckp_path, 'strategy'), 'w') as fp:
dump(strategy, fp)
def _load_experiment_checkpoint(self, exp_work_dir: PathLike) -> Tuple[Model, List[Mutator], BaseStrategy]:
ckp_path = os.path.join(exp_work_dir, self.id, 'checkpoint')
with open(os.path.join(ckp_path, 'nas_model'), 'r') as fp:
base_model_ir = load(fp=fp)
base_model_ir = Model._load(base_model_ir)
with open(os.path.join(ckp_path, 'applied_mutators'), 'r') as fp:
applied_mutators = load(fp=fp)
with open(os.path.join(ckp_path, 'strategy'), 'r') as fp:
strategy = load(fp=fp)
return base_model_ir, applied_mutators, strategy
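The checkpoint layout written by _save_experiment_checkpoint and read back by _load_experiment_checkpoint (paths taken from the code above; contents are serialized with nni.common.dump/load):

# <experiment_working_directory>/<experiment_id>/checkpoint/
#     nas_model         # the model IR, i.e. base_model_ir._dump()
#     applied_mutators  # the list of Mutator objects
#     strategy          # the BaseStrategy object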
def start(self, *args, **kwargs) -> None:
"""
By design, the only difference between `start` and `run` is that `start` is asynchronous,
......@@ -262,7 +256,6 @@ class RetiariiExperiment(Experiment):
Run the experiment.
This function will block until experiment finish or error.
"""
from nni.retiarii.oneshot.interface import BaseOneShotTrainer
if isinstance(self.evaluator, BaseOneShotTrainer):
warnings.warn('You are using the old implementation of one-shot algos based on One-shot trainer. '
......@@ -287,15 +280,30 @@ class RetiariiExperiment(Experiment):
self.strategy.run(base_model_ir, self.applied_mutators)
else:
ws_url = f'ws://localhost:{port}/tuner'
canonicalized_config = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
canonicalized_config = cast(RetiariiExeConfig, canonicalized_config)
canoni_conf = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
canoni_conf = cast(RetiariiExeConfig, canoni_conf)
self._dispatcher = RetiariiAdvisor(ws_url)
self._dispatcher_thread = Thread(target=self._dispatcher.run, daemon=True)
self._dispatcher_thread.start()
# FIXME: engine cannot be created twice
self._create_execution_engine(canonicalized_config)
self._create_execution_engine(canoni_conf)
try:
self._run_strategy(canonicalized_config)
if self._action == 'create':
base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.evaluator, self.applied_mutators,
full_ir=not isinstance(canoni_conf.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
dummy_input=canoni_conf.execution_engine.dummy_input
if isinstance(canoni_conf.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
)
self._save_experiment_checkpoint(base_model_ir, self.applied_mutators, self.strategy,
canoni_conf.experiment_working_directory)
elif self._action == 'resume':
base_model_ir, self.applied_mutators, self.strategy = self._load_experiment_checkpoint(
canoni_conf.experiment_working_directory)
else:
raise RuntimeError(f'The experiment mode "{self._action}" is not supposed to invoke run() method.')
self._run_strategy(base_model_ir, self.applied_mutators)
# FIXME: move this logic to strategy with a new API provided by execution engine
self._wait_completion()
except KeyboardInterrupt:
......@@ -359,3 +367,67 @@ class RetiariiExperiment(Experiment):
return [model_to_pytorch_script(model) for model in all_models[:top_k]]
elif formatter == 'dict':
return [get_mutation_dict(model) for model in all_models[:top_k]]
@staticmethod
def view(experiment_id: str, port: int = 8080, non_blocking: bool = False) -> RetiariiExperiment | None:
"""
View a stopped experiment.
Parameters
----------
experiment_id
The stopped experiment id.
port
The port of web UI.
non_blocking
If false, run in the foreground. If true, run in the background.
"""
experiment = RetiariiExperiment._view(experiment_id)
# view has nothing specific to RetiariiExperiment; directly use the method from the base experiment class
super(RetiariiExperiment, experiment).start(port=port, debug=False, run_mode=RunMode.Detach)
if non_blocking:
return experiment
else:
try:
while True:
time.sleep(10)
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
finally:
experiment.stop()
@staticmethod
def resume(experiment_id: str, port: int = 8080, debug: bool = False) -> RetiariiExperiment:
"""
Resume a stopped experiment.
Parameters
----------
experiment_id
The stopped experiment id.
port
The port of web UI.
debug
Whether to start in debug mode.
"""
experiment = RetiariiExperiment._resume(experiment_id)
experiment.run(experiment.config, port=port, debug=debug)
# always return the experiment for the user's follow-up operations on it
# wait_completion is not necessary, as a NAS experiment always runs in the foreground
return experiment
@staticmethod
def _resume(exp_id, exp_dir=None):
exp = RetiariiExperiment(cast(nn.Module, None))
exp.id = exp_id
exp._action = 'resume'
exp.config = cast(RetiariiExeConfig, launcher.get_stopped_experiment_config(exp_id, exp_dir))
return exp
@staticmethod
def _view(exp_id, exp_dir=None):
exp = RetiariiExperiment(cast(nn.Module, None))
exp.id = exp_id
exp._action = 'view'
exp.config = cast(RetiariiExeConfig, launcher.get_stopped_experiment_config(exp_id, exp_dir))
return exp
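A hedged usage sketch of the new static methods (the experiment id is a placeholder; behaviour follows the docstrings above):

exp = RetiariiExperiment.resume('EXPERIMENT_ID', port=8080)        # re-runs and blocks until finish
exp.stop()

exp = RetiariiExperiment.view('EXPERIMENT_ID', non_blocking=True)  # detached web UI for a stopped run
exp.stop()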
......@@ -104,9 +104,14 @@ def resume_experiment(args):
legacy_launcher.resume_experiment(args)
exit()
exp = Experiment._resume(exp_id, exp_dir)
run_mode = RunMode.Foreground if foreground else RunMode.Detach
exp.start(port, debug, run_mode)
exp_cls, _ = utils.get_experiment_cls_using_config(config_json)
if exp_cls is Experiment:
exp = exp_cls._resume(exp_id, exp_dir)
run_mode = RunMode.Foreground if foreground else RunMode.Detach
exp.start(port, debug, run_mode)
else:
# exp_cls is RetiariiExperiment
exp_cls.resume(exp_id, port, debug)
def view_experiment(args):
exp_id = args.id
......@@ -118,5 +123,10 @@ def view_experiment(args):
legacy_launcher.view_experiment(args)
exit()
exp = Experiment._view(exp_id, exp_dir)
exp.start(port, run_mode=RunMode.Detach)
exp_cls, _ = utils.get_experiment_cls_using_config(config_json)
if exp_cls is Experiment:
exp = exp_cls._view(exp_id, exp_dir)
exp.start(port, run_mode=RunMode.Detach)
else:
# exp_cls is RetiariiExperiment
exp_cls.view(exp_id, port, non_blocking=True)
import multiprocessing
import os
import sys
import subprocess
import time
import pytest
......@@ -76,3 +76,44 @@ def test_exp_exit_without_stop(pytestconfig):
return
process.kill()
raise RuntimeError(f'Experiment fails to stop in {timeout} seconds.')
def test_multitrial_experiment_resume_view(pytestconfig):
# start a normal nas experiment
base_model, evaluator = _mnist_net('simple', {'max_epochs': 1})
search_strategy = strategy.Random()
exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
exp_id = exp.id
exp_config = RetiariiExeConfig('local')
exp_config.trial_concurrency = 1
exp_config.max_trial_number = 1
exp_config._trial_command_params = nas_experiment_trial_params(pytestconfig.rootpath)
exp.run(exp_config)
ensure_success(exp)
assert isinstance(exp.export_top_models()[0], dict)
exp.stop()
# resume the above NAS experiment. This only tests the resume logic on the Python side:
# no more trials are executed after resume, because the above experiment has already finished
print('python api resume...')
exp = RetiariiExperiment.resume(exp_id)
ensure_success(exp)
# TODO: currently `export_top_models` does not work, as the strategy's state is not resumed
# assert isinstance(exp.export_top_models()[0], dict)
exp.stop()
# view the above experiment in non-blocking mode, then stop it
print('python api view...')
exp = RetiariiExperiment.view(exp_id, non_blocking=True)
exp.stop()
# the following tests nnictl resume and view
print('nnictl resume...')
new_env = os.environ.copy()
new_env['PYTHONPATH'] = str(pytestconfig.rootpath)
proc = subprocess.run(f'nnictl resume {exp_id}', shell=True, env=new_env)
assert proc.returncode == 0, 'resume nas experiment failed with code %d' % proc.returncode
print('nnictl view...')
proc = subprocess.run(f'nnictl view {exp_id}', shell=True)
assert proc.returncode == 0, 'view nas experiment failed with code %d' % proc.returncode
proc = subprocess.run(f'nnictl stop {exp_id}', shell=True)
assert proc.returncode == 0, 'stop viewed nas experiment failed with code %d' % proc.returncode
\ No newline at end of file
......@@ -28,6 +28,7 @@ minimal_class.trial_concurrency = 2
minimal_class.tuner.name = 'random'
minimal_canon = {
'experimentType': 'hpo',
'searchSpace': {'a': 1},
'trialCommand': 'python main.py',
'trialCodeDirectory': os.path.realpath('.'),
......@@ -54,6 +55,7 @@ minimal_canon_2['tuner']['classArgs'] = {}
detailed_canon = {
'experimentName': 'test case',
'experimentType': 'hpo',
'searchSpaceFile': expand_path('assets/search_space.json'),
'searchSpace': {'a': 1},
'trialCommand': 'python main.py',
......
......@@ -43,6 +43,7 @@ minimal_class = ExperimentConfig(
)
minimal_canon = {
'experimentType': 'hpo',
'searchSpace': {'a': 1},
'trialCommand': 'python main.py',
'trialCodeDirectory': os.path.realpath('.'),
......@@ -106,6 +107,7 @@ detailed_json = {
}
detailed_canon = {
'experimentType': 'hpo',
'searchSpace': {'a': 1},
'trialCommand': 'python main.py',
'trialCodeDirectory': os.path.realpath('.'),
......
......@@ -37,6 +37,10 @@ def _test_file(json_path):
# skip comparison of _evaluator
orig_ir.pop('_evaluator')
dump_ir.pop('_evaluator')
# skip three experiment fields
dump_ir.pop('model_id')
dump_ir.pop('python_class')
dump_ir.pop('python_init_params')
assert orig_ir == dump_ir
......