Unverified Commit 98c1a77f authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Support multiple HPO experiments in one process (#4855)

parent 5dc80762
...@@ -9,7 +9,6 @@ import base64 ...@@ -9,7 +9,6 @@ import base64
from .runtime.msg_dispatcher import MsgDispatcher from .runtime.msg_dispatcher import MsgDispatcher
from .runtime.msg_dispatcher_base import MsgDispatcherBase from .runtime.msg_dispatcher_base import MsgDispatcherBase
from .runtime.protocol import connect_websocket
from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance
logger = logging.getLogger('nni.main') logger = logging.getLogger('nni.main')
...@@ -21,10 +20,6 @@ if os.environ.get('COVERAGE_PROCESS_START'): ...@@ -21,10 +20,6 @@ if os.environ.get('COVERAGE_PROCESS_START'):
def main(): def main():
# the url should be "ws://localhost:{port}/tuner" or "ws://localhost:{port}/{url_prefix}/tuner"
ws_url = os.environ['NNI_TUNER_COMMAND_CHANNEL']
connect_websocket(ws_url)
parser = argparse.ArgumentParser(description='Dispatcher command line parser') parser = argparse.ArgumentParser(description='Dispatcher command line parser')
parser.add_argument('--exp_params', type=str, required=True) parser.add_argument('--exp_params', type=str, required=True)
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
...@@ -56,7 +51,10 @@ def main(): ...@@ -56,7 +51,10 @@ def main():
assessor = _create_algo(exp_params['assessor'], 'assessor') assessor = _create_algo(exp_params['assessor'], 'assessor')
else: else:
assessor = None assessor = None
dispatcher = MsgDispatcher(tuner, assessor)
# the url should be "ws://localhost:{port}/tuner" or "ws://localhost:{port}/{url_prefix}/tuner"
url = os.environ['NNI_TUNER_COMMAND_CHANNEL']
dispatcher = MsgDispatcher(url, tuner, assessor)
try: try:
dispatcher.run() dispatcher.run()
......
...@@ -14,7 +14,7 @@ from ConfigSpace.read_and_write import pcs_new ...@@ -14,7 +14,7 @@ from ConfigSpace.read_and_write import pcs_new
import nni import nni
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.runtime.protocol import CommandType, send from nni.runtime.tuner_command_channel import CommandType
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.utils import OptimizeMode, MetricType, extract_scalar_reward from nni.utils import OptimizeMode, MetricType, extract_scalar_reward
from nni.runtime.common import multi_phase_enabled from nni.runtime.common import multi_phase_enabled
...@@ -483,7 +483,7 @@ class BOHB(MsgDispatcherBase): ...@@ -483,7 +483,7 @@ class BOHB(MsgDispatcherBase):
raise ValueError('Error: Search space is None') raise ValueError('Error: Search space is None')
# generate first brackets # generate first brackets
self.generate_new_bracket() self.generate_new_bracket()
send(CommandType.Initialized, '') self.send(CommandType.Initialized, '')
def generate_new_bracket(self): def generate_new_bracket(self):
"""generate a new bracket""" """generate a new bracket"""
...@@ -541,7 +541,7 @@ class BOHB(MsgDispatcherBase): ...@@ -541,7 +541,7 @@ class BOHB(MsgDispatcherBase):
'parameter_source': 'algorithm', 'parameter_source': 'algorithm',
'parameters': '' 'parameters': ''
} }
send(CommandType.NoMoreTrialJobs, nni.dump(ret)) self.send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None return None
assert self.generated_hyper_configs assert self.generated_hyper_configs
params = self.generated_hyper_configs.pop(0) params = self.generated_hyper_configs.pop(0)
...@@ -572,7 +572,7 @@ class BOHB(MsgDispatcherBase): ...@@ -572,7 +572,7 @@ class BOHB(MsgDispatcherBase):
""" """
ret = self._get_one_trial_job() ret = self._get_one_trial_job()
if ret is not None: if ret is not None:
send(CommandType.NewTrialJob, nni.dump(ret)) self.send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1 self.credit -= 1
def handle_update_search_space(self, data): def handle_update_search_space(self, data):
...@@ -664,7 +664,7 @@ class BOHB(MsgDispatcherBase): ...@@ -664,7 +664,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = one_unsatisfied['parameter_index'] ret['parameter_index'] = one_unsatisfied['parameter_index']
# update parameter_id in self.job_id_para_id_map # update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[ret['trial_job_id']] = ret['parameter_id'] self.job_id_para_id_map[ret['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, nni.dump(ret)) self.send(CommandType.SendTrialJobParameter, nni.dump(ret))
for _ in range(self.credit): for _ in range(self.credit):
self._request_one_trial_job() self._request_one_trial_job()
...@@ -712,7 +712,7 @@ class BOHB(MsgDispatcherBase): ...@@ -712,7 +712,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = data['parameter_index'] ret['parameter_index'] = data['parameter_index']
# update parameter_id in self.job_id_para_id_map # update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id'] self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, nni.dump(ret)) self.send(CommandType.SendTrialJobParameter, nni.dump(ret))
else: else:
assert 'value' in data assert 'value' in data
value = extract_scalar_reward(data['value']) value = extract_scalar_reward(data['value'])
......
...@@ -18,7 +18,7 @@ from nni import ClassArgsValidator ...@@ -18,7 +18,7 @@ from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space from nni.common.hpo_utils import validate_search_space
from nni.runtime.common import multi_phase_enabled from nni.runtime.common import multi_phase_enabled
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send from nni.runtime.tuner_command_channel import CommandType
from nni.utils import NodeType, OptimizeMode, MetricType, extract_scalar_reward from nni.utils import NodeType, OptimizeMode, MetricType, extract_scalar_reward
from nni import parameter_expressions from nni import parameter_expressions
...@@ -432,7 +432,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -432,7 +432,7 @@ class Hyperband(MsgDispatcherBase):
search space search space
""" """
self.handle_update_search_space(data) self.handle_update_search_space(data)
send(CommandType.Initialized, '') self.send(CommandType.Initialized, '')
def handle_request_trial_jobs(self, data): def handle_request_trial_jobs(self, data):
""" """
...@@ -449,7 +449,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -449,7 +449,7 @@ class Hyperband(MsgDispatcherBase):
def _request_one_trial_job(self): def _request_one_trial_job(self):
ret = self._get_one_trial_job() ret = self._get_one_trial_job()
if ret is not None: if ret is not None:
send(CommandType.NewTrialJob, nni.dump(ret)) self.send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1 self.credit -= 1
def _get_one_trial_job(self): def _get_one_trial_job(self):
...@@ -478,7 +478,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -478,7 +478,7 @@ class Hyperband(MsgDispatcherBase):
'parameter_source': 'algorithm', 'parameter_source': 'algorithm',
'parameters': '' 'parameters': ''
} }
send(CommandType.NoMoreTrialJobs, nni.dump(ret)) self.send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None return None
assert self.generated_hyper_configs assert self.generated_hyper_configs
...@@ -553,7 +553,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -553,7 +553,7 @@ class Hyperband(MsgDispatcherBase):
if data['parameter_index'] is not None: if data['parameter_index'] is not None:
ret['parameter_index'] = data['parameter_index'] ret['parameter_index'] = data['parameter_index']
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id'] self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, nni.dump(ret)) self.send(CommandType.SendTrialJobParameter, nni.dump(ret))
else: else:
value = extract_scalar_reward(data['value']) value = extract_scalar_reward(data['value'])
bracket_id, i, _ = data['parameter_id'].split('_') bracket_id, i, _ = data['parameter_id'].split('_')
......
...@@ -24,7 +24,7 @@ from nni.experiment.config import utils ...@@ -24,7 +24,7 @@ from nni.experiment.config import utils
from nni.experiment.config.base import ConfigBase from nni.experiment.config.base import ConfigBase
from nni.experiment.config.training_service import TrainingServiceConfig from nni.experiment.config.training_service import TrainingServiceConfig
from nni.experiment.config.training_services import RemoteConfig from nni.experiment.config.training_services import RemoteConfig
from nni.runtime.protocol import connect_websocket from nni.runtime.tuner_command_channel import TunerCommandChannel
from nni.tools.nnictl.command_utils import kill_command from nni.tools.nnictl.command_utils import kill_command
from ..codegen import model_to_pytorch_script from ..codegen import model_to_pytorch_script
...@@ -274,7 +274,8 @@ class RetiariiExperiment(Experiment): ...@@ -274,7 +274,8 @@ class RetiariiExperiment(Experiment):
from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
if not isinstance(strategy, OneShotStrategy): if not isinstance(strategy, OneShotStrategy):
self._dispatcher = RetiariiAdvisor() # FIXME: Dispatcher should not be created this early.
self._dispatcher = RetiariiAdvisor('_placeholder_')
else: else:
self._dispatcher = cast(RetiariiAdvisor, None) self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread: Optional[Thread] = None self._dispatcher_thread: Optional[Thread] = None
...@@ -357,13 +358,14 @@ class RetiariiExperiment(Experiment): ...@@ -357,13 +358,14 @@ class RetiariiExperiment(Experiment):
self._proc = launcher.start_experiment('create', self.id, self.config, port, debug, # type: ignore self._proc = launcher.start_experiment('create', self.id, self.config, port, debug, # type: ignore
RunMode.Background, None, ws_url, ['retiarii']) RunMode.Background, None, ws_url, ['retiarii'])
assert self._proc is not None assert self._proc is not None
connect_websocket(ws_url)
self.port = port # port will be None if start up failed self.port = port # port will be None if start up failed
# dispatcher must be launched after pipe initialized # dispatcher must be launched after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api # the logic to launch dispatcher in background should be refactored into dispatcher api
self._dispatcher = self._create_dispatcher() self._dispatcher = self._create_dispatcher()
if self._dispatcher is not None:
self._dispatcher._channel = TunerCommandChannel(ws_url)
self._dispatcher_thread = Thread(target=self._dispatcher.run) self._dispatcher_thread = Thread(target=self._dispatcher.run)
self._dispatcher_thread.start() self._dispatcher_thread.start()
......
...@@ -9,7 +9,7 @@ import nni ...@@ -9,7 +9,7 @@ import nni
from nni.common.serializer import PayloadTooLarge from nni.common.serializer import PayloadTooLarge
from nni.common.version import version_dump from nni.common.version import version_dump
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send from nni.runtime.tuner_command_channel import CommandType
from nni.utils import MetricType from nni.utils import MetricType
from .graph import MetricData from .graph import MetricData
...@@ -48,8 +48,8 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -48,8 +48,8 @@ class RetiariiAdvisor(MsgDispatcherBase):
final_metric_callback final_metric_callback
""" """
def __init__(self): def __init__(self, url: str):
super(RetiariiAdvisor, self).__init__() super().__init__(url)
register_advisor(self) # register the current advisor as the "global only" advisor register_advisor(self) # register the current advisor as the "global only" advisor
self.search_space = None self.search_space = None
...@@ -69,7 +69,7 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -69,7 +69,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
search space search space
""" """
self.handle_update_search_space(data) self.handle_update_search_space(data)
send(CommandType.Initialized, '') self.send(CommandType.Initialized, '')
def _validate_placement_constraint(self, placement_constraint): def _validate_placement_constraint(self, placement_constraint):
if placement_constraint is None: if placement_constraint is None:
...@@ -138,14 +138,14 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -138,14 +138,14 @@ class RetiariiAdvisor(MsgDispatcherBase):
# trial parameters can be super large, disable pickle size limit here # trial parameters can be super large, disable pickle size limit here
# nevertheless, it could still be blocked by pipe / nni-manager # nevertheless, it could still be blocked by pipe / nni-manager
send(CommandType.NewTrialJob, send_payload) self.send(CommandType.NewTrialJob, send_payload)
if self.send_trial_callback is not None: if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count return self.parameters_count
def mark_experiment_as_ending(self): def mark_experiment_as_ending(self):
send(CommandType.NoMoreTrialJobs, '') self.send(CommandType.NoMoreTrialJobs, '')
def handle_request_trial_jobs(self, num_trials): def handle_request_trial_jobs(self, num_trials):
_logger.debug('Request trial jobs: %s', num_trials) _logger.debug('Request trial jobs: %s', num_trials)
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
_multi_thread = False
_multi_phase = False _multi_phase = False
def enable_multi_thread():
global _multi_thread
_multi_thread = True
def multi_thread_enabled():
return _multi_thread
def enable_multi_phase(): def enable_multi_phase():
global _multi_phase global _multi_phase
_multi_phase = True _multi_phase = True
......
...@@ -22,7 +22,8 @@ _dispatcher_env_var_names = [ ...@@ -22,7 +22,8 @@ _dispatcher_env_var_names = [
'NNI_CHECKPOINT_DIRECTORY', 'NNI_CHECKPOINT_DIRECTORY',
'NNI_LOG_DIRECTORY', 'NNI_LOG_DIRECTORY',
'NNI_LOG_LEVEL', 'NNI_LOG_LEVEL',
'NNI_INCLUDE_INTERMEDIATE_RESULTS' 'NNI_INCLUDE_INTERMEDIATE_RESULTS',
'NNI_TUNER_COMMAND_CHANNEL',
] ]
def _load_env_vars(env_var_names): def _load_env_vars(env_var_names):
......
...@@ -7,10 +7,10 @@ from collections import defaultdict ...@@ -7,10 +7,10 @@ from collections import defaultdict
from nni import NoMoreTrialError from nni import NoMoreTrialError
from nni.assessor import AssessResult from nni.assessor import AssessResult
from .common import multi_thread_enabled, multi_phase_enabled from .common import multi_phase_enabled
from .env_vars import dispatcher_env_vars from .env_vars import dispatcher_env_vars
from .msg_dispatcher_base import MsgDispatcherBase from .msg_dispatcher_base import MsgDispatcherBase
from .protocol import CommandType, send from .tuner_command_channel import CommandType
from ..common.serializer import dump, load from ..common.serializer import dump, load
from ..utils import MetricType from ..utils import MetricType
...@@ -67,8 +67,8 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p ...@@ -67,8 +67,8 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p
class MsgDispatcher(MsgDispatcherBase): class MsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None): def __init__(self, command_channel_url, tuner, assessor=None):
super(MsgDispatcher, self).__init__() super().__init__(command_channel_url)
self.tuner = tuner self.tuner = tuner
self.assessor = assessor self.assessor = assessor
if assessor is None: if assessor is None:
...@@ -88,12 +88,12 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -88,12 +88,12 @@ class MsgDispatcher(MsgDispatcherBase):
"""Data is search space """Data is search space
""" """
self.tuner.update_search_space(data) self.tuner.update_search_space(data)
send(CommandType.Initialized, '') self.send(CommandType.Initialized, '')
def send_trial_callback(self, id_, params): def send_trial_callback(self, id_, params):
"""For tuner to issue trial config when the config is generated """For tuner to issue trial config when the config is generated
""" """
send(CommandType.NewTrialJob, _pack_parameter(id_, params)) self.send(CommandType.NewTrialJob, _pack_parameter(id_, params))
def handle_request_trial_jobs(self, data): def handle_request_trial_jobs(self, data):
# data: number of trial jobs # data: number of trial jobs
...@@ -102,10 +102,10 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -102,10 +102,10 @@ class MsgDispatcher(MsgDispatcherBase):
params_list = self.tuner.generate_multiple_parameters(ids, st_callback=self.send_trial_callback) params_list = self.tuner.generate_multiple_parameters(ids, st_callback=self.send_trial_callback)
for i, _ in enumerate(params_list): for i, _ in enumerate(params_list):
send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i])) self.send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
# when parameters is None. # when parameters is None.
if len(params_list) < len(ids): if len(params_list) < len(ids):
send(CommandType.NoMoreTrialJobs, _pack_parameter(ids[0], '')) self.send(CommandType.NoMoreTrialJobs, _pack_parameter(ids[0], ''))
def handle_update_search_space(self, data): def handle_update_search_space(self, data):
self.tuner.update_search_space(data) self.tuner.update_search_space(data)
...@@ -148,7 +148,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -148,7 +148,7 @@ class MsgDispatcher(MsgDispatcherBase):
param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id']) param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
except NoMoreTrialError: except NoMoreTrialError:
param = None param = None
send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], self.send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'],
parameter_index=data['parameter_index'])) parameter_index=data['parameter_index']))
else: else:
raise ValueError('Data type not supported: {}'.format(data['type'])) raise ValueError('Data type not supported: {}'.format(data['type']))
...@@ -222,7 +222,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -222,7 +222,7 @@ class MsgDispatcher(MsgDispatcherBase):
if result is AssessResult.Bad: if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id) _logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, dump(trial_job_id)) self.send(CommandType.KillTrialJob, dump(trial_job_id))
# notify tuner # notify tuner
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]',
dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS) dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
...@@ -237,8 +237,5 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -237,8 +237,5 @@ class MsgDispatcher(MsgDispatcherBase):
""" """
_logger.debug('Early stop notify tuner data: [%s]', data) _logger.debug('Early stop notify tuner data: [%s]', data)
data['type'] = MetricType.FINAL data['type'] = MetricType.FINAL
if multi_thread_enabled(): data['value'] = dump(data['value'])
self._handle_final_metric_data(data) self.enqueue_command(CommandType.ReportMetricData, data)
else:
data['value'] = dump(data['value'])
self.enqueue_command(CommandType.ReportMetricData, data)
...@@ -3,14 +3,12 @@ ...@@ -3,14 +3,12 @@
import threading import threading
import logging import logging
from multiprocessing.dummy import Pool as ThreadPool
from queue import Queue, Empty from queue import Queue, Empty
from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars from .env_vars import dispatcher_env_vars
from ..common import load from ..common import load
from ..recoverable import Recoverable from ..recoverable import Recoverable
from .protocol import CommandType, receive from .tuner_command_channel import CommandType, TunerCommandChannel
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -24,59 +22,52 @@ class MsgDispatcherBase(Recoverable): ...@@ -24,59 +22,52 @@ class MsgDispatcherBase(Recoverable):
Inherits this class to make your own advisor. Inherits this class to make your own advisor.
""" """
def __init__(self): def __init__(self, command_channel_url=None):
self.stopping = False self.stopping = False
if multi_thread_enabled(): if command_channel_url is None:
self.pool = ThreadPool() command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
self.thread_results = [] self._channel = TunerCommandChannel(command_channel_url)
else: self.default_command_queue = Queue()
self.default_command_queue = Queue() self.assessor_command_queue = Queue()
self.assessor_command_queue = Queue() self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,)) self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,))
self.assessor_worker = threading.Thread(target=self.command_queue_worker, self.worker_exceptions = []
args=(self.assessor_command_queue,))
self.default_worker.start()
self.assessor_worker.start()
self.worker_exceptions = []
def run(self): def run(self):
"""Run the tuner. """Run the tuner.
This function will never return unless raise. This function will never return unless raise.
""" """
_logger.info('Dispatcher started') _logger.info('Dispatcher started')
self._channel.connect()
self.default_worker.start()
self.assessor_worker.start()
if dispatcher_env_vars.NNI_MODE == 'resume': if dispatcher_env_vars.NNI_MODE == 'resume':
self.load_checkpoint() self.load_checkpoint()
while not self.stopping: while not self.stopping:
command, data = receive() command, data = self._channel._receive()
if data: if data:
data = load(data) data = load(data)
if command is None or command is CommandType.Terminate: if command is None or command is CommandType.Terminate:
break break
if multi_thread_enabled(): self.enqueue_command(command, data)
result = self.pool.map_async(self.process_command_thread, [(command, data)]) if self.worker_exceptions:
self.thread_results.append(result) break
if any([thread_result.ready() and not thread_result.successful() for thread_result in
self.thread_results]):
_logger.debug('Caught thread exception')
break
else:
self.enqueue_command(command, data)
if self.worker_exceptions:
break
_logger.info('Dispatcher exiting...') _logger.info('Dispatcher exiting...')
self.stopping = True self.stopping = True
if multi_thread_enabled(): self.default_worker.join()
self.pool.close() self.assessor_worker.join()
self.pool.join() self._channel.disconnect()
else:
self.default_worker.join()
self.assessor_worker.join()
_logger.info('Dispatcher terminiated') _logger.info('Dispatcher terminiated')
def send(self, command, data):
self._channel._send(command, data)
def command_queue_worker(self, command_queue): def command_queue_worker(self, command_queue):
"""Process commands in command queues. """Process commands in command queues.
""" """
...@@ -112,19 +103,6 @@ class MsgDispatcherBase(Recoverable): ...@@ -112,19 +103,6 @@ class MsgDispatcherBase(Recoverable):
if qsize >= QUEUE_LEN_WARNING_MARK: if qsize >= QUEUE_LEN_WARNING_MARK:
_logger.warning('assessor queue length: %d', qsize) _logger.warning('assessor queue length: %d', qsize)
def process_command_thread(self, request):
"""Worker thread to process a command.
"""
command, data = request
if multi_thread_enabled():
try:
self.process_command(command, data)
except Exception as e:
_logger.exception(str(e))
raise
else:
pass
def process_command(self, command, data): def process_command(self, command, data):
_logger.debug('process_command: command: [%s], data: [%s]', command, data) _logger.debug('process_command: command: [%s], data: [%s]', command, data)
...@@ -242,4 +220,4 @@ class MsgDispatcherBase(Recoverable): ...@@ -242,4 +220,4 @@ class MsgDispatcherBase(Recoverable):
hyper_params: the string that is sent by message dispatcher during the creation of trials. hyper_params: the string that is sent by message dispatcher during the creation of trials.
""" """
raise NotImplementedError('handle_trial_end not implemented') raise NotImplementedError('handle_trial_end not implemented')
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=unused-import
from __future__ import annotations
from .tuner_command_channel.command_type import CommandType
from .tuner_command_channel import legacy
from .tuner_command_channel import shim
_use_ws = False
def connect_websocket(url: str):
    """Switch this module to the WebSocket backend and connect it to ``url``.

    After this call, :func:`send` and :func:`receive` are routed to ``shim``
    instead of ``legacy``.
    """
    global _use_ws
    # The flag is flipped before connecting, so it stays set even if
    # ``shim.connect`` raises.
    _use_ws = True
    shim.connect(url)
def send(command: CommandType, data: str) -> None:
    """Send ``command`` with payload ``data`` through the active backend.

    The WebSocket shim is used once :func:`connect_websocket` has been called;
    otherwise the legacy channel is used.
    """
    backend = shim if _use_ws else legacy
    backend.send(command, data)
def receive() -> tuple[CommandType, str] | tuple[None, None]:
    """Receive one command from the active backend and return its result.

    The WebSocket shim is used once :func:`connect_websocket` has been called;
    otherwise the legacy channel is used.
    """
    backend = shim if _use_ws else legacy
    return backend.receive()
# for unit test compatibility
def _set_in_file(in_file):
    """Assign ``in_file`` to ``legacy._in_file`` (kept for unit tests)."""
    legacy._in_file = in_file
def _set_out_file(out_file):
    """Assign ``out_file`` to ``legacy._out_file`` (kept for unit tests)."""
    legacy._out_file = out_file
def _get_out_file():
    """Return the current value of ``legacy._out_file`` (kept for unit tests)."""
    return legacy._out_file
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
# Licensed under the MIT license. # Licensed under the MIT license.
""" """
The IPC channel between tuner/assessor and NNI manager. Low level APIs for algorithms to communicate with NNI manager.
Work in progress.
""" """
from .command_type import CommandType
from .channel import TunerCommandChannel
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Low level APIs for algorithms to communicate with NNI manager.
"""
from __future__ import annotations
__all__ = ['TunerCommandChannel']
from .command_type import CommandType
from .websocket import WebSocket
class TunerCommandChannel:
    """
    Communication channel between an algorithm and NNI manager.

    Every NNI experiment exposes one channel URL for its tuner/assessor/strategy algorithm.
    A channel can be connected only once, so each Python side :class:`~nni.experiment.Experiment` object
    should be paired with exactly one ``TunerCommandChannel`` instance.

    Call :meth:`connect` before sending or receiving any data.

    Constructing a ``TunerCommandChannel`` has no side effect, so it can be created anywhere.
    :meth:`connect`, however, requires an initialized NNI manager; otherwise the behavior is unpredictable.

    :meth:`_send` and :meth:`_receive` are underscore-prefixed because their signatures are scheduled to change by v3.0.

    Parameters
    ----------
    url
        The command channel URL.
        For now it must be like ``"ws://localhost:8080/tuner"`` or ``"ws://localhost:8080/url-prefix/tuner"``.
    """

    def __init__(self, url: str):
        # The connection itself is opened later, by connect().
        self._channel = WebSocket(url)

    def connect(self) -> None:
        """Open the underlying channel."""
        self._channel.connect()

    def disconnect(self) -> None:
        """Close the underlying channel."""
        self._channel.disconnect()

    # TODO: Define semantic command class like `KillTrialJob(trial_id='abc')`.
    #
    # def send(self, command: Command) -> None:
    #     ...
    #
    # def receive(self) -> Command | None:
    #     ...

    def _send(self, command_type: CommandType, data: str) -> None:
        """Send one command: the decoded ``command_type.value`` followed by ``data``."""
        type_code = command_type.value.decode()
        self._channel.send(type_code + data)

    def _receive(self) -> tuple[CommandType, str] | tuple[None, None]:
        """Receive one command and split it into ``(command_type, data)``.

        The first two characters of the message select the :class:`CommandType`;
        the remainder is returned unchanged as the data string.

        Raises
        ------
        RuntimeError
            If the underlying channel returns ``None``.
        """
        # NOTE(review): the annotation allows ``(None, None)``, but this
        # implementation raises instead of returning it.
        message = self._channel.receive()
        if message is None:
            raise RuntimeError('NNI manager closed connection')
        type_code, payload = message[:2], message[2:]
        return CommandType(type_code.encode()), payload
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
__all__ = [
'CommandType',
'LegacyCommandChannel',
'send',
'receive',
'_set_in_file',
'_set_out_file',
'_get_out_file',
]
import logging import logging
import os import os
import threading import threading
...@@ -18,6 +28,29 @@ try: ...@@ -18,6 +28,29 @@ try:
except OSError: except OSError:
_logger.debug('IPC pipeline not exists') _logger.debug('IPC pipeline not exists')
def _set_in_file(in_file):
    """Rebind the module-level ``_in_file`` to ``in_file``."""
    global _in_file
    _in_file = in_file
def _set_out_file(out_file):
    """Rebind the module-level ``_out_file`` to ``out_file``."""
    global _out_file
    _out_file = out_file
def _get_out_file():
    """Return the current module-level ``_out_file``."""
    return _out_file
class LegacyCommandChannel:
    """Channel object that forwards to this module's ``send`` and ``receive``.

    It offers the ``connect`` / ``disconnect`` / ``_send`` / ``_receive``
    methods; connecting and disconnecting do nothing here.
    """

    def connect(self):
        """Do nothing."""

    def disconnect(self):
        """Do nothing."""

    def _send(self, command, data):
        """Forward ``command`` and ``data`` to the module-level ``send``."""
        send(command, data)

    def _receive(self):
        """Return whatever the module-level ``receive`` returns."""
        return receive()
def send(command, data): def send(command, data):
"""Send command to Training Service. """Send command to Training Service.
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Compatibility layer for old protocol APIs.
We are working on more semantic new APIs.
"""
from __future__ import annotations
from .command_type import CommandType
from .websocket import WebSocket
_ws: WebSocket = None # type: ignore
def connect(url: str) -> None:
    """Create the module-level WebSocket for ``url`` and connect it."""
    global _ws
    # ``_ws`` is rebound before connecting, so it refers to the new socket
    # even if ``connect()`` raises.
    _ws = WebSocket(url)
    _ws.connect()
def send(command_type: CommandType, data: str) -> None:
    """Send one command over the module-level WebSocket.

    The message is the decoded ``command_type.value`` followed by ``data``.
    """
    type_code = command_type.value.decode()
    _ws.send(type_code + data)
def receive() -> tuple[CommandType, str]:
    """Receive one command from the module-level WebSocket.

    The first two characters of the message select the ``CommandType``; the
    remainder is returned unchanged as the data string. On a ``Terminate``
    command the WebSocket is disconnected before returning.

    Raises
    ------
    RuntimeError
        If the WebSocket returns ``None``.
    """
    message = _ws.receive()
    if message is None:
        raise RuntimeError('NNI manager closed connection')
    type_code, payload = message[:2], message[2:]
    command_type = CommandType(type_code.encode())
    if command_type is CommandType.Terminate:
        _ws.disconnect()
    return command_type, payload
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import json import json
import logging
import os import os
from schema import And, Optional, Or, Regex, Schema, SchemaError from schema import And, Optional, Or, Regex, Schema, SchemaError
...@@ -77,7 +76,6 @@ class AlgoSchema: ...@@ -77,7 +76,6 @@ class AlgoSchema:
if not builtin_name or not class_args: if not builtin_name or not class_args:
return return
logging.getLogger('nni.protocol').setLevel(logging.ERROR) # we know IPC is not there, don't complain
validator = create_validator_instance(algo_type+'s', builtin_name) validator = create_validator_instance(algo_type+'s', builtin_name)
if validator: if validator:
try: try:
......
...@@ -88,11 +88,14 @@ def create_experiment(args): ...@@ -88,11 +88,14 @@ def create_experiment(args):
exp = Experiment(config) exp = Experiment(config)
exp.url_prefix = url_prefix exp.url_prefix = url_prefix
run_mode = RunMode.Foreground if foreground else RunMode.Detach
exp.start(port, debug, run_mode)
_logger.info(f'To stop experiment run "nnictl stop {exp.id}" or "nnictl stop --all"') if foreground:
_logger.info('Reference: https://nni.readthedocs.io/en/stable/reference/nnictl.html') exp.run(port, debug=debug)
else:
exp.start(port, debug, RunMode.Detach)
_logger.info(f'To stop experiment run "nnictl stop {exp.id}" or "nnictl stop --all"')
_logger.info('Reference: https://nni.readthedocs.io/en/stable/reference/nnictl.html')
def resume_experiment(args): def resume_experiment(args):
exp_id = args.id exp_id = args.id
......
...@@ -10,6 +10,7 @@ from pathlib import Path ...@@ -10,6 +10,7 @@ from pathlib import Path
import nni import nni
import nni.runtime.platform.test import nni.runtime.platform.test
from nni.runtime.tuner_command_channel import legacy as protocol
import json import json
try: try:
...@@ -262,7 +263,11 @@ class CGOEngineTest(unittest.TestCase): ...@@ -262,7 +263,11 @@ class CGOEngineTest(unittest.TestCase):
opt = DedupInputOptimizer() opt = DedupInputOptimizer()
opt.convert(lp) opt.convert(lp)
advisor = RetiariiAdvisor() advisor = RetiariiAdvisor('ws://_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1), GPUDevice("test", 2), GPUDevice("test", 3)] available_devices = [GPUDevice("test", 0), GPUDevice("test", 1), GPUDevice("test", 2), GPUDevice("test", 3)]
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0) cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
...@@ -281,7 +286,11 @@ class CGOEngineTest(unittest.TestCase): ...@@ -281,7 +286,11 @@ class CGOEngineTest(unittest.TestCase):
opt = DedupInputOptimizer() opt = DedupInputOptimizer()
opt.convert(lp) opt.convert(lp)
advisor = RetiariiAdvisor() advisor = RetiariiAdvisor('ws://_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1)] available_devices = [GPUDevice("test", 0), GPUDevice("test", 1)]
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0) cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
...@@ -296,14 +305,17 @@ class CGOEngineTest(unittest.TestCase): ...@@ -296,14 +305,17 @@ class CGOEngineTest(unittest.TestCase):
_reset() _reset()
nni.retiarii.debug_configs.framework = 'pytorch' nni.retiarii.debug_configs.framework = 'pytorch'
os.makedirs('generated', exist_ok=True) os.makedirs('generated', exist_ok=True)
from nni.runtime import protocol
import nni.runtime.platform.test as tt import nni.runtime.platform.test as tt
protocol._set_out_file(open('generated/debug_protocol_out_file.py', 'wb')) protocol._set_out_file(open('generated/debug_protocol_out_file.py', 'wb'))
protocol._set_in_file(open('generated/debug_protocol_out_file.py', 'rb')) protocol._set_in_file(open('generated/debug_protocol_out_file.py', 'rb'))
models = _load_mnist(2) models = _load_mnist(2)
advisor = RetiariiAdvisor() advisor = RetiariiAdvisor('ws://_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
cgo_engine = CGOExecutionEngine(devices=[GPUDevice("test", 0), GPUDevice("test", 1), cgo_engine = CGOExecutionEngine(devices=[GPUDevice("test", 0), GPUDevice("test", 1),
GPUDevice("test", 2), GPUDevice("test", 3)], batch_waiting_time=0) GPUDevice("test", 2), GPUDevice("test", 3)], batch_waiting_time=0)
set_execution_engine(cgo_engine) set_execution_engine(cgo_engine)
......
...@@ -11,7 +11,7 @@ from nni.retiarii.execution.base import BaseExecutionEngine ...@@ -11,7 +11,7 @@ from nni.retiarii.execution.base import BaseExecutionEngine
from nni.retiarii.execution.python import PurePythonExecutionEngine from nni.retiarii.execution.python import PurePythonExecutionEngine
from nni.retiarii.graph import DebugEvaluator from nni.retiarii.graph import DebugEvaluator
from nni.retiarii.integration import RetiariiAdvisor from nni.retiarii.integration import RetiariiAdvisor
from nni.runtime.tuner_command_channel.legacy import *
class EngineTest(unittest.TestCase): class EngineTest(unittest.TestCase):
def test_codegen(self): def test_codegen(self):
...@@ -25,7 +25,11 @@ class EngineTest(unittest.TestCase): ...@@ -25,7 +25,11 @@ class EngineTest(unittest.TestCase):
def test_base_execution_engine(self): def test_base_execution_engine(self):
nni.retiarii.integration_api._advisor = None nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor() advisor = RetiariiAdvisor('ws://_placeholder_')
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
set_execution_engine(BaseExecutionEngine()) set_execution_engine(BaseExecutionEngine())
with open(self.enclosing_dir / 'mnist_pytorch.json') as f: with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
model = Model._load(json.load(f)) model = Model._load(json.load(f))
...@@ -38,7 +42,11 @@ class EngineTest(unittest.TestCase): ...@@ -38,7 +42,11 @@ class EngineTest(unittest.TestCase):
def test_py_execution_engine(self): def test_py_execution_engine(self):
nni.retiarii.integration_api._advisor = None nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor() advisor = RetiariiAdvisor('ws://_placeholder_')
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
set_execution_engine(PurePythonExecutionEngine()) set_execution_engine(PurePythonExecutionEngine())
model = Model._load({ model = Model._load({
'_model': { '_model': {
...@@ -63,11 +71,9 @@ class EngineTest(unittest.TestCase): ...@@ -63,11 +71,9 @@ class EngineTest(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.enclosing_dir = Path(__file__).parent self.enclosing_dir = Path(__file__).parent
os.makedirs(self.enclosing_dir / 'generated', exist_ok=True) os.makedirs(self.enclosing_dir / 'generated', exist_ok=True)
from nni.runtime import protocol _set_out_file(open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb'))
protocol._set_out_file(open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb'))
def tearDown(self) -> None: def tearDown(self) -> None:
from nni.runtime import protocol _get_out_file().close()
protocol._get_out_file().close()
nni.retiarii.execution.api._execution_engine = None nni.retiarii.execution.api._execution_engine = None
nni.retiarii.integration_api._advisor = None nni.retiarii.integration_api._advisor = None
...@@ -8,8 +8,7 @@ from unittest import TestCase, main ...@@ -8,8 +8,7 @@ from unittest import TestCase, main
from nni.assessor import Assessor, AssessResult from nni.assessor import Assessor, AssessResult
from nni.runtime import msg_dispatcher_base as msg_dispatcher_base from nni.runtime import msg_dispatcher_base as msg_dispatcher_base
from nni.runtime.msg_dispatcher import MsgDispatcher from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.runtime import protocol from nni.runtime.tuner_command_channel.legacy import *
from nni.runtime.protocol import CommandType, send, receive
_trials = [] _trials = []
_end_trials = [] _end_trials = []
...@@ -34,15 +33,15 @@ _out_buf = BytesIO() ...@@ -34,15 +33,15 @@ _out_buf = BytesIO()
def _reverse_io(): def _reverse_io():
_in_buf.seek(0) _in_buf.seek(0)
_out_buf.seek(0) _out_buf.seek(0)
protocol._set_out_file(_in_buf) _set_out_file(_in_buf)
protocol._set_in_file(_out_buf) _set_in_file(_out_buf)
def _restore_io(): def _restore_io():
_in_buf.seek(0) _in_buf.seek(0)
_out_buf.seek(0) _out_buf.seek(0)
protocol._set_in_file(_in_buf) _set_in_file(_in_buf)
protocol._set_out_file(_out_buf) _set_out_file(_out_buf)
class AssessorTestCase(TestCase): class AssessorTestCase(TestCase):
...@@ -58,7 +57,8 @@ class AssessorTestCase(TestCase): ...@@ -58,7 +57,8 @@ class AssessorTestCase(TestCase):
_restore_io() _restore_io()
assessor = NaiveAssessor() assessor = NaiveAssessor()
dispatcher = MsgDispatcher(None, assessor) dispatcher = MsgDispatcher('ws://_placeholder_', None, assessor)
dispatcher._channel = LegacyCommandChannel()
msg_dispatcher_base._worker_fast_exit_on_terminate = False msg_dispatcher_base._worker_fast_exit_on_terminate = False
dispatcher.run() dispatcher.run()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment