Unverified Commit 99e2991a authored by QuanluZhang, committed by GitHub

Retiarii exp launch (#3424)

parent bc55eec6
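For orientation, the launch path added in this commit is driven through the RetiariiExperiment / RetiariiExeConfig API modified below. A minimal usage sketch follows; it is illustrative only: the module path nni.retiarii.experiment.pytorch, the RetiariiExperiment constructor arguments, and the experiment_name / trial_concurrency / max_trial_number fields are assumptions based on surrounding NNI conventions, not part of this diff.

# base_model, trainer and strategy are hypothetical user-supplied objects;
# their construction is outside the scope of this diff.
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment

exp = RetiariiExperiment(base_model, trainer, [], strategy)  # constructor signature assumed

exp_config = RetiariiExeConfig('local')        # training_service_platform, as in __init__ below
exp_config.experiment_name = 'retiarii demo'   # assumed standard config field
exp_config.trial_concurrency = 2               # assumed standard config field
exp_config.max_trial_number = 10               # assumed standard config field

exp.run(exp_config, 8081)                      # blocks until the experiment finishes (see _run below)
print(exp.export_top_models(top_n=1))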
@@ -10,9 +10,11 @@ from typing import Optional, Tuple
import colorama
import nni_node  # pylint: disable=import-error
import nni.runtime.protocol
from .config import ExperimentConfig
from .config import convert
from .pipe import Pipe
from . import rest
from ..tools.nnictl.config_utils import Experiments
@@ -48,6 +50,43 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
                proc.kill()
        raise e
def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]:
    pipe = None
    proc = None

    config.validate(initialized_tuner=True)
    _ensure_port_idle(port)
    if isinstance(config.training_service, list):  # hybrid training service
        _ensure_port_idle(port + 1, 'Hybrid training service requires an additional port')
    elif config.training_service.platform in ['remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']:
        _ensure_port_idle(port + 1, f'{config.training_service.platform} requires an additional port')

    try:
        _logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
        pipe = Pipe(exp_id)
        start_time, proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
        _logger.info('Connecting IPC pipe...')
        pipe_file = pipe.connect()
        nni.runtime.protocol._in_file = pipe_file
        nni.runtime.protocol._out_file = pipe_file
        _logger.info('Starting web server...')
        _check_rest_server(port)
        platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
        _save_experiment_information(exp_id, port, start_time, platform,
                                     config.experiment_name, proc.pid, config.experiment_working_directory)
        _logger.info('Setting up...')
        _init_experiment(config, port, debug)
        return proc, pipe
    except Exception as e:
        _logger.error('Create experiment failed')
        if proc is not None:
            with contextlib.suppress(Exception):
                proc.kill()
        if pipe is not None:
            with contextlib.suppress(Exception):
                pipe.close()
        raise e
def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
    sock = socket.socket()
@@ -57,7 +96,7 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
        raise RuntimeError(f'Port {port} is not idle {message}')
def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str = None) -> Tuple[int, Popen]:
    if isinstance(config.training_service, list):
        ts = 'hybrid'
    else:
@@ -72,6 +111,8 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
        'start_mode': 'new',
        'log_level': 'debug' if debug else 'info',
    }
    if pipe_path is not None:
        args['dispatcher_pipe'] = pipe_path
    node_dir = Path(nni_node.__path__[0])
    node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
@@ -85,8 +126,11 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
        from subprocess import CREATE_NEW_PROCESS_GROUP
        proc = Popen(cmd, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP)
    else:
        if pipe_path is None:
            import os
            proc = Popen(cmd, cwd=node_dir, preexec_fn=os.setpgrp)
        else:
            proc = Popen(cmd, cwd=node_dir)
    return int(time.time() * 1000), proc
...
import atexit
import logging
import time
from dataclasses import dataclass
import os
from pathlib import Path
import socket
from subprocess import Popen
from threading import Thread
from typing import Any, List, Optional, Union

import colorama
import psutil
import torch
import torch.nn as nn

import nni.runtime.log
from nni.experiment import Experiment, TrainingServiceConfig
from nni.experiment import management, launcher, rest
from nni.experiment.config import util
from nni.experiment.config.base import ConfigBase, PathLike
from nni.experiment.pipe import Pipe
from nni.tools.nnictl.command_utils import kill_command

from ..converter import convert_to_graph
from ..graph import Model, Evaluator
@@ -47,6 +57,16 @@ class RetiariiExeConfig(ConfigBase):
            assert 'training_service' not in kwargs
            self.training_service = util.training_service_config_factory(platform=training_service_platform)
    def __setattr__(self, key, value):
        fixed_attrs = {'search_space': '',
                       'trial_command': 'python3 -m nni.retiarii.trial_entry'}
        if key in fixed_attrs and fixed_attrs[key] != value:
            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
        # 'trial_code_directory' is handled differently because the path will be converted to absolute path by us
        if key == 'trial_code_directory' and not (value == Path('.') or os.path.isabs(value)):
            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
        self.__dict__[key] = value
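The __setattr__ override above pins the fields that Retiarii itself owns. A standalone sketch of the same guard pattern, using hypothetical names rather than NNI code:

class FixedFieldConfig:
    """Config whose 'command' field is owned by the framework, not the user."""

    _FIXED = {'command': 'python3 -m my_framework.trial_entry'}  # hypothetical fixed value

    def __setattr__(self, key, value):
        # Reject any attempt to override a framework-managed field with a different value.
        if key in self._FIXED and value != self._FIXED[key]:
            raise AttributeError(f'{key} is not supposed to be set by users!')
        self.__dict__[key] = value


cfg = FixedFieldConfig()
cfg.command = 'python3 -m my_framework.trial_entry'  # accepted: matches the fixed value
try:
    cfg.command = 'python3 my_trial.py'              # rejected by the guard
except AttributeError as err:
    print(err)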
    def validate(self, initialized_tuner: bool = False) -> None:
        super().validate()
@@ -131,7 +151,37 @@ class RetiariiExperiment(Experiment):
        debug
            Whether to start in debug mode.
        """
        atexit.register(self.stop)
        self.id = management.generate_experiment_id()

        if self.config.experiment_working_directory is not None:
            log_dir = Path(self.config.experiment_working_directory, self.id, 'log')
        else:
            log_dir = Path.home() / f'nni-experiments/{self.id}/log'
        nni.runtime.log.start_experiment_log(self.id, log_dir, debug)

        self._proc, self._pipe = launcher.start_experiment_retiarii(self.id, self.config, port, debug)
        assert self._proc is not None
        assert self._pipe is not None
        self.port = port  # port will be None if start up failed

        # dispatcher must be launched after pipe initialized
        # the logic to launch dispatcher in background should be refactored into dispatcher api
        self._dispatcher = self._create_dispatcher()
        self._dispatcher_thread = Thread(target=self._dispatcher.run)
        self._dispatcher_thread.start()

        ips = [self.config.nni_manager_ip]
        for interfaces in psutil.net_if_addrs().values():
            for interface in interfaces:
                if interface.family == socket.AF_INET:
                    ips.append(interface.address)
        ips = [f'http://{ip}:{port}' for ip in ips if ip]
        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
        _logger.info(msg)
        self._start_strategy()
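The Web UI URL list logged above is collected from every local IPv4 interface. A minimal standalone sketch of that lookup, using the same psutil and socket calls (the port value is arbitrary):

import socket

import psutil

port = 8080
ips = []
for interfaces in psutil.net_if_addrs().values():  # {interface name: [address records]}
    for interface in interfaces:
        if interface.family == socket.AF_INET:      # keep IPv4 addresses only
            ips.append(interface.address)
urls = [f'http://{ip}:{port}' for ip in ips if ip]
print('Web UI URLs:', ' '.join(urls))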
    def _create_dispatcher(self):
@@ -151,7 +201,58 @@ class RetiariiExperiment(Experiment):
        else:
            assert config is not None, 'You are using classic search mode, config cannot be None!'
            self.config = config
            self._run(port, debug)
    def _run(self, port: int = 8080, debug: bool = False) -> bool:
        """
        Run the experiment.
        This function blocks until the experiment finishes or errors.
        Return `True` when the experiment is done, or `False` when it failed.
        """
        self.start(port, debug)
        try:
            while True:
                time.sleep(10)
                status = self.get_status()
                if status == 'DONE' or status == 'STOPPED':
                    return True
                if status == 'ERROR':
                    return False
        except KeyboardInterrupt:
            _logger.warning('KeyboardInterrupt detected')
        finally:
            self.stop()
    def stop(self) -> None:
        """
        Stop the background experiment.
        """
        _logger.info('Stopping experiment, please wait...')
        atexit.unregister(self.stop)

        if self.id is not None:
            nni.runtime.log.stop_experiment_log(self.id)
        if self._proc is not None:
            try:
                rest.delete(self.port, '/experiment')
            except Exception as e:
                _logger.exception(e)
                _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
                kill_command(self._proc.pid)

        if self._pipe is not None:
            self._pipe.close()
        if self._dispatcher_thread is not None:
            self._dispatcher.stopping = True
            self._dispatcher_thread.join(timeout=1)

        self.id = None
        self.port = None
        self._proc = None
        self._pipe = None
        self._dispatcher = None
        self._dispatcher_thread = None
        _logger.info('Experiment stopped')
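Note the pairing with start(): start() registers stop() via atexit, and stop() unregisters it, so a normal shutdown is not followed by a second stop at interpreter exit. A standalone sketch of that pattern with hypothetical names:

import atexit


class Service:
    def start(self):
        atexit.register(self.stop)    # ensure cleanup runs even if the process exits unexpectedly
        print('service started')

    def stop(self):
        atexit.unregister(self.stop)  # already stopping; avoid a second stop at interpreter exit
        print('service stopped')


svc = Service()
svc.start()
svc.stop()  # nothing remains registered with atexit after this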
    def export_top_models(self, top_n: int = 1):
        """
...
@@ -341,7 +341,6 @@ class NNIManager implements Manager {
                }
            }
        }
        await this.trainingService.cleanUp();
        if (this.experimentProfile.endTime === undefined) {
            this.setEndtime();
...