Unverified Commit 08986c6b authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Do Not Merge] add resume and view mode in python experiment (#3490)

parent aec4ce14
...@@ -75,6 +75,7 @@ class Experiment: ...@@ -75,6 +75,7 @@ class Experiment:
self.id: Optional[str] = None self.id: Optional[str] = None
self.port: Optional[int] = None self.port: Optional[int] = None
self._proc: Optional[Popen] = None self._proc: Optional[Popen] = None
self.mode = 'new'
args = [config, training_service] # deal with overloading args = [config, training_service] # deal with overloading
if isinstance(args[0], (str, list)): if isinstance(args[0], (str, list)):
...@@ -101,7 +102,10 @@ class Experiment: ...@@ -101,7 +102,10 @@ class Experiment:
""" """
atexit.register(self.stop) atexit.register(self.stop)
if self.mode == 'new':
self.id = management.generate_experiment_id() self.id = management.generate_experiment_id()
else:
self.config = launcher.get_stopped_experiment_config(self.id, self.mode)
if self.config.experiment_working_directory is not None: if self.config.experiment_working_directory is not None:
log_dir = Path(self.config.experiment_working_directory, self.id, 'log') log_dir = Path(self.config.experiment_working_directory, self.id, 'log')
...@@ -109,7 +113,7 @@ class Experiment: ...@@ -109,7 +113,7 @@ class Experiment:
log_dir = Path.home() / f'nni-experiments/{self.id}/log' log_dir = Path.home() / f'nni-experiments/{self.id}/log'
nni.runtime.log.start_experiment_log(self.id, log_dir, debug) nni.runtime.log.start_experiment_log(self.id, log_dir, debug)
self._proc = launcher.start_experiment(self.id, self.config, port, debug) self._proc = launcher.start_experiment(self.id, self.config, port, debug, mode=self.mode)
assert self._proc is not None assert self._proc is not None
self.port = port # port will be None if start up failed self.port = port # port will be None if start up failed
...@@ -189,6 +193,42 @@ class Experiment: ...@@ -189,6 +193,42 @@ class Experiment:
_logger.info('Connect to port %d success, experiment id is %s, status is %s.', port, experiment.id, status) _logger.info('Connect to port %d success, experiment id is %s, status is %s.', port, experiment.id, status)
return experiment return experiment
@classmethod
def resume(cls, experiment_id: str, port: int, wait_completion: bool = True, debug: bool = False) -> 'Experiment':
    """
    Resume a stopped experiment.

    Parameters
    ----------
    experiment_id
        The stopped experiment id.
    port
        Port passed through to ``run()`` / ``start()``.
    wait_completion
        If true, launch the experiment with ``run()``; otherwise launch it
        with ``start()``.
        NOTE(review): ``run()`` presumably blocks until the experiment
        finishes -- confirm against its definition.
    debug
        Debug flag passed through to ``run()`` / ``start()``.

    Returns
    -------
    Experiment
        The experiment object created in ``'resume'`` mode.
    """
    experiment = Experiment()
    # The id must be set before launching: in any mode other than 'new',
    # start() uses self.id to look up the stopped experiment's config.
    experiment.id = experiment_id
    experiment.mode = 'resume'
    if wait_completion:
        experiment.run(port, debug)
    else:
        experiment.start(port, debug)
    return experiment
@classmethod
def view(cls, experiment_id: str, port: int, wait_completion: bool = True, debug: bool = False) -> 'Experiment':
    """
    View a stopped experiment.

    Parameters
    ----------
    experiment_id
        The stopped experiment id.
    port
        Port passed through to ``run()`` / ``start()``.
    wait_completion
        If true, launch the experiment with ``run()``; otherwise launch it
        with ``start()``.
        NOTE(review): ``run()`` presumably blocks until the experiment
        is stopped -- confirm against its definition.
    debug
        Debug flag passed through to ``run()`` / ``start()``.

    Returns
    -------
    Experiment
        The experiment object created in ``'view'`` mode.
    """
    experiment = Experiment()
    # The id must be set before launching: in any mode other than 'new',
    # start() uses self.id to look up the stopped experiment's config.
    experiment.id = experiment_id
    experiment.mode = 'view'
    if wait_completion:
        experiment.run(port, debug)
    else:
        experiment.start(port, debug)
    return experiment
def get_status(self) -> str: def get_status(self) -> str:
""" """
Return experiment status as a str. Return experiment status as a str.
......
...@@ -18,16 +18,19 @@ import nni.runtime.protocol ...@@ -18,16 +18,19 @@ import nni.runtime.protocol
from .config import ExperimentConfig from .config import ExperimentConfig
from .pipe import Pipe from .pipe import Pipe
from . import rest from . import rest
from ..tools.nnictl.config_utils import Experiments from ..tools.nnictl.config_utils import Experiments, Config
from ..tools.nnictl.nnictl_utils import update_experiment
_logger = logging.getLogger('nni.experiment') _logger = logging.getLogger('nni.experiment')
def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Popen: def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bool, mode: str = 'new') -> Popen:
proc = None proc = None
config.validate(initialized_tuner=False) config.validate(initialized_tuner=False)
_ensure_port_idle(port) _ensure_port_idle(port)
if mode != 'view':
if isinstance(config.training_service, list): # hybrid training service if isinstance(config.training_service, list): # hybrid training service
_ensure_port_idle(port + 1, 'Hybrid training service requires an additional port') _ensure_port_idle(port + 1, 'Hybrid training service requires an additional port')
elif config.training_service.platform in ['remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']: elif config.training_service.platform in ['remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']:
...@@ -35,12 +38,13 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo ...@@ -35,12 +38,13 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
try: try:
_logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL) _logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
start_time, proc = _start_rest_server(config, port, debug, exp_id) start_time, proc = _start_rest_server(config, port, debug, exp_id, mode=mode)
_logger.info('Statring web server...') _logger.info('Statring web server...')
_check_rest_server(port) _check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
_save_experiment_information(exp_id, port, start_time, platform, _save_experiment_information(exp_id, port, start_time, platform,
config.experiment_name, proc.pid, config.experiment_working_directory) config.experiment_name, proc.pid, config.experiment_working_directory)
if mode != 'view':
_logger.info('Setting up...') _logger.info('Setting up...')
rest.post(port, '/experiment', config.json()) rest.post(port, '/experiment', config.json())
return proc return proc
...@@ -98,7 +102,8 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None: ...@@ -98,7 +102,8 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
raise RuntimeError(f'Port {port} is not idle {message}') raise RuntimeError(f'Port {port} is not idle {message}')
def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str = None) -> Tuple[int, Popen]: def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str = None,
mode: str = 'new') -> Tuple[int, Popen]:
if isinstance(config.training_service, list): if isinstance(config.training_service, list):
ts = 'hybrid' ts = 'hybrid'
else: else:
...@@ -110,12 +115,16 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim ...@@ -110,12 +115,16 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
'port': port, 'port': port,
'mode': ts, 'mode': ts,
'experiment_id': experiment_id, 'experiment_id': experiment_id,
'start_mode': 'new', 'start_mode': mode,
'log_level': 'debug' if debug else 'info', 'log_level': 'debug' if debug else 'info',
} }
if pipe_path is not None: if pipe_path is not None:
args['dispatcher_pipe'] = pipe_path args['dispatcher_pipe'] = pipe_path
if mode == 'view':
args['start_mode'] = 'resume'
args['readonly'] = 'true'
node_dir = Path(nni_node.__path__[0]) node_dir = Path(nni_node.__path__[0])
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node')) node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js') main_js = str(node_dir / 'main.js')
...@@ -150,3 +159,19 @@ def _check_rest_server(port: int, retry: int = 3) -> None: ...@@ -150,3 +159,19 @@ def _check_rest_server(port: int, retry: int = 3) -> None:
def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None: def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None:
experiments_config = Experiments() experiments_config = Experiments()
experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir) experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir)
def get_stopped_experiment_config(exp_id: str, mode: str) -> 'Optional[ExperimentConfig]':
    """
    Load the saved configuration of a stopped experiment.

    Parameters
    ----------
    exp_id
        Id of the experiment to look up in the experiments record.
    mode
        The action being attempted (e.g. ``'resume'`` or ``'view'``); only
        used to build the error message.

    Returns
    -------
    ExperimentConfig or None
        The config rebuilt from the stored experiment config, or ``None``
        (after logging an error) when ``exp_id`` is unknown or the
        experiment's status is not ``'STOPPED'``. Callers must check for
        ``None`` before using the result.
    """
    # NOTE(review): presumably refreshes the recorded experiment statuses
    # before they are read below -- confirm against nnictl_utils.
    update_experiment()
    experiment_metadata = Experiments().get_all_experiments().get(exp_id)
    if experiment_metadata is None:
        _logger.error('Id %s not exist!', exp_id)
        return None
    if experiment_metadata['status'] != 'STOPPED':
        # Build a proper past tense: 'resume' -> 'resumed', 'view' -> 'viewed'.
        past_tense = mode + ('d' if mode.endswith('e') else 'ed')
        _logger.error('Only stopped experiments can be %s!', past_tense)
        return None
    experiment_config = Config(exp_id, experiment_metadata['logDir']).get_config()
    return ExperimentConfig(**experiment_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment