Unverified commit 63697ec5, authored by chicm-ms, committed by GitHub

Route tuner and assessor commands to 2 separate queues (#891)

1. Route tuner and assessor commands to 2 separate queues (issue #841).
2. Allow the tuner to leverage intermediate results when a trial is early stopped (issue #843).
parent c297650a
@@ -231,12 +231,17 @@ machineList:
 * __classArgs__
   __classArgs__ specifies the arguments of the tuner algorithm.
 * __gpuNum__
   __gpuNum__ specifies the number of GPUs used to run the tuner process. The value of this field should be a positive number.
 Note: users can specify the tuner in only one way, for example, {tunerName, optimizationMode} or {tunerCommand, tunerCwd}, but not both.
+* __includeIntermediateResults__
+  If __includeIntermediateResults__ is true, the last intermediate result of a trial that is early stopped by the assessor is sent to the tuner as its final result. The default value of __includeIntermediateResults__ is false.
 * __assessor__
   * Description
...
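For illustration, here is a minimal sketch of what this option means for a custom tuner (the `RecordingTuner` class and the metric values are hypothetical, not NNI code): when `includeIntermediateResults` is true, the dispatcher forwards an early-stopped trial's last intermediate metric to `receive_trial_result` as its final result.

```python
# Hypothetical sketch, not NNI's implementation: with
# includeIntermediateResults: true, an early-stopped trial's last
# intermediate metric is delivered via receive_trial_result as final.

class RecordingTuner:
    """Records the final result reported for each parameter id."""

    def __init__(self):
        self.final_results = {}

    def receive_trial_result(self, parameter_id, parameters, value):
        # `value` may be the last intermediate result of a trial that the
        # assessor stopped early, re-tagged as FINAL by the dispatcher.
        self.final_results[parameter_id] = value


tuner = RecordingTuner()
# Suppose trial 7 reported intermediates [0.51, 0.58, 0.60] and was then
# early stopped; the dispatcher forwards the last value (0.60) as final.
tuner.receive_trial_result(7, {'lr': 0.01}, 0.60)
assert tuner.final_results[7] == 0.60
```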
@@ -46,6 +46,7 @@ interface ExperimentParams {
         classFileName?: string;
         checkpointDir: string;
         gpuNum?: number;
+        includeIntermediateResults?: boolean;
     };
     assessor?: {
         className: string;
...
@@ -277,11 +277,17 @@ class NNIManager implements Manager {
             newCwd = cwd;
         }
         // TO DO: add CUDA_VISIBLE_DEVICES
+        let includeIntermediateResultsEnv: boolean | undefined = false;
+        if (this.experimentProfile.params.tuner !== undefined) {
+            includeIntermediateResultsEnv = this.experimentProfile.params.tuner.includeIntermediateResults;
+        }
         let nniEnv = {
             NNI_MODE: mode,
             NNI_CHECKPOINT_DIRECTORY: dataDirectory,
             NNI_LOG_DIRECTORY: getLogDir(),
-            NNI_LOG_LEVEL: getLogLevel()
+            NNI_LOG_LEVEL: getLogLevel(),
+            NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv
         };
         let newEnv = Object.assign({}, process.env, nniEnv);
         const tunerProc: ChildProcess = spawn(command, [], {
...
@@ -106,7 +106,7 @@ describe('core/ipcInterface.terminate', (): void => {
                 assert.ok(!procError);
                 deferred.resolve();
             },
-            2000);
+            5000);
         return deferred.promise;
     });
...
@@ -159,7 +159,8 @@ export namespace ValidationSchemas {
                 className: joi.string(),
                 classArgs: joi.any(),
                 gpuNum: joi.number().min(0),
-                checkpointDir: joi.string().allow('')
+                checkpointDir: joi.string().allow(''),
+                includeIntermediateResults: joi.boolean()
             }),
             assessor: joi.object({
                 builtinAssessorName: joi.string().valid('Medianstop', 'Curvefitting'),
...
@@ -18,14 +18,15 @@
 # OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 # ==================================================================================================
+import os
 import logging
 from collections import defaultdict
 import json_tricks
+import threading

 from .protocol import CommandType, send
 from .msg_dispatcher_base import MsgDispatcherBase
 from .assessor import AssessResult
+from .common import multi_thread_enabled

 _logger = logging.getLogger(__name__)
@@ -70,7 +71,7 @@ def _pack_parameter(parameter_id, params, customized=False):
 class MsgDispatcher(MsgDispatcherBase):
     def __init__(self, tuner, assessor=None):
-        super().__init__()
+        super(MsgDispatcher, self).__init__()
         self.tuner = tuner
         self.assessor = assessor
         if assessor is None:
@@ -87,9 +88,8 @@ class MsgDispatcher(MsgDispatcherBase):
         self.assessor.save_checkpoint()

     def handle_initialize(self, data):
-        '''
-        data is search space
-        '''
+        """Data is search space
+        """
         self.tuner.update_search_space(data)
         send(CommandType.Initialized, '')
         return True
@@ -126,12 +126,7 @@ class MsgDispatcher(MsgDispatcherBase):
         - 'type': report type, support {'FINAL', 'PERIODICAL'}
         """
         if data['type'] == 'FINAL':
-            id_ = data['parameter_id']
-            value = data['value']
-            if id_ in _customized_parameter_ids:
-                self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
-            else:
-                self.tuner.receive_trial_result(id_, _trial_params[id_], value)
+            self._handle_final_metric_data(data)
         elif data['type'] == 'PERIODICAL':
             if self.assessor is not None:
                 self._handle_intermediate_metric_data(data)
@@ -157,7 +152,19 @@ class MsgDispatcher(MsgDispatcherBase):
             self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
         return True

+    def _handle_final_metric_data(self, data):
+        """Call tuner to process final results
+        """
+        id_ = data['parameter_id']
+        value = data['value']
+        if id_ in _customized_parameter_ids:
+            self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
+        else:
+            self.tuner.receive_trial_result(id_, _trial_params[id_], value)
+
     def _handle_intermediate_metric_data(self, data):
+        """Call assessor to process intermediate results
+        """
         if data['type'] != 'PERIODICAL':
             return True
         if self.assessor is None:
@@ -187,5 +194,20 @@ class MsgDispatcher(MsgDispatcherBase):
         if result is AssessResult.Bad:
             _logger.debug('BAD, kill %s', trial_job_id)
             send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
+            # notify tuner
+            _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS'))
+            if os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS') == 'true':
+                self._earlystop_notify_tuner(data)
         else:
             _logger.debug('GOOD')
+
+    def _earlystop_notify_tuner(self, data):
+        """Send last intermediate result as final result to tuner in case the
+        trial is early stopped.
+        """
+        _logger.debug('Early stop notify tuner data: [%s]', data)
+        data['type'] = 'FINAL'
+        if multi_thread_enabled():
+            self._handle_final_metric_data(data)
+        else:
+            self.enqueue_command(CommandType.ReportMetricData, data)
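The non-multi-thread branch of `_earlystop_notify_tuner` above re-tags the metric and re-enqueues it rather than calling the tuner directly, so the tuner still processes it on its own worker thread. A minimal sketch of that hand-off, assuming a dict-shaped metric command (the command name string and dict keys here are illustrative):

```python
# Sketch only: re-tag the last intermediate result as FINAL and hand it to
# the tuner-bound queue, mirroring _earlystop_notify_tuner's non-multi-thread
# path. The 'ReportMetricData' tag and keys are illustrative placeholders.
from queue import Queue

default_command_queue = Queue()  # commands drained by the tuner worker

data = {'type': 'PERIODICAL', 'parameter_id': 7, 'value': 0.60}
data['type'] = 'FINAL'  # the killed trial's last metric becomes its final result
default_command_queue.put(('ReportMetricData', data))

command, payload = default_command_queue.get()
assert command == 'ReportMetricData' and payload['type'] == 'FINAL'
```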
@@ -19,14 +19,13 @@
 # ==================================================================================================

 #import json_tricks
-import logging
 import os
-from queue import Queue
-import sys
+import threading
+import logging
 from multiprocessing.dummy import Pool as ThreadPool
+from queue import Queue, Empty

 import json_tricks

 from .common import init_logger, multi_thread_enabled
 from .recoverable import Recoverable
 from .protocol import CommandType, receive
@@ -34,57 +33,109 @@ from .protocol import CommandType, receive
 init_logger('dispatcher.log')
 _logger = logging.getLogger(__name__)

+QUEUE_LEN_WARNING_MARK = 20
+
+_worker_fast_exit_on_terminate = True

 class MsgDispatcherBase(Recoverable):
     def __init__(self):
         if multi_thread_enabled():
             self.pool = ThreadPool()
             self.thread_results = []
+        else:
+            self.stopping = False
+            self.default_command_queue = Queue()
+            self.assessor_command_queue = Queue()
+            self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
+            self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,))
+            self.default_worker.start()
+            self.assessor_worker.start()
+            self.worker_exceptions = []

     def run(self):
         """Run the tuner.
         This function will never return unless raise.
         """
+        _logger.info('Start dispatcher')
         mode = os.getenv('NNI_MODE')
         if mode == 'resume':
             self.load_checkpoint()

         while True:
-            _logger.debug('waiting receive_message')
             command, data = receive()
+            if data:
+                data = json_tricks.loads(data)
+
             if command is None or command is CommandType.Terminate:
                 break
             if multi_thread_enabled():
-                result = self.pool.map_async(self.handle_request_thread, [(command, data)])
+                result = self.pool.map_async(self.process_command_thread, [(command, data)])
                 self.thread_results.append(result)
                 if any([thread_result.ready() and not thread_result.successful() for thread_result in self.thread_results]):
                     _logger.debug('Caught thread exception')
                     break
             else:
-                self.handle_request((command, data))
+                self.enqueue_command(command, data)

+        _logger.info('Dispatcher exiting...')
+        self.stopping = True
         if multi_thread_enabled():
             self.pool.close()
             self.pool.join()
+        else:
+            self.default_worker.join()
+            self.assessor_worker.join()

         _logger.info('Terminated by NNI manager')

-    def handle_request_thread(self, request):
+    def command_queue_worker(self, command_queue):
+        """Process commands in command queues.
+        """
+        while True:
+            try:
+                # set timeout to ensure self.stopping is checked periodically
+                command, data = command_queue.get(timeout=3)
+                try:
+                    self.process_command(command, data)
+                except Exception as e:
+                    _logger.exception(e)
+                    self.worker_exceptions.append(e)
+                    break
+            except Empty:
+                pass
+            if self.stopping and (_worker_fast_exit_on_terminate or command_queue.empty()):
+                break
+
+    def enqueue_command(self, command, data):
+        """Enqueue command into command queues
+        """
+        if command == CommandType.TrialEnd or (command == CommandType.ReportMetricData and data['type'] == 'PERIODICAL'):
+            self.assessor_command_queue.put((command, data))
+        else:
+            self.default_command_queue.put((command, data))
+
+        qsize = self.default_command_queue.qsize()
+        if qsize >= QUEUE_LEN_WARNING_MARK:
+            _logger.warning('default queue length: %d', qsize)
+
+        qsize = self.assessor_command_queue.qsize()
+        if qsize >= QUEUE_LEN_WARNING_MARK:
+            _logger.warning('assessor queue length: %d', qsize)
+
+    def process_command_thread(self, request):
+        """Worker thread to process a command.
+        """
+        command, data = request
         if multi_thread_enabled():
             try:
-                self.handle_request(request)
+                self.process_command(command, data)
             except Exception as e:
                 _logger.exception(str(e))
                 raise
         else:
             pass

-    def handle_request(self, request):
-        command, data = request
-        _logger.debug('handle request: command: [{}], data: [{}]'.format(command, data))
-        if data:
-            data = json_tricks.loads(data)
+    def process_command(self, command, data):
+        _logger.debug('process_command: command: [{}], data: [{}]'.format(command, data))

         command_handlers = {
             # Tunner commands:
...
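Taken together, the changes above replace the single `handle_request` path with two queues drained by dedicated worker threads, so a slow tuner handler cannot delay assessor decisions and vice versa. Below is a self-contained sketch of the pattern; it is a simplified stand-in, not the `MsgDispatcherBase` code itself, and the command name strings are illustrative:

```python
# Two-queue dispatch sketch: assessor-bound commands go to one queue,
# everything else to another, each drained by its own worker thread.
import threading
from queue import Queue, Empty

stopping = False
default_queue, assessor_queue = Queue(), Queue()

def worker(queue, name):
    # Drain one queue; the get() timeout keeps the thread responsive to
    # `stopping` without busy-waiting, like command_queue_worker above.
    while True:
        try:
            command, data = queue.get(timeout=0.1)
            print('%s worker handled %s' % (name, command))
        except Empty:
            pass
        if stopping and queue.empty():
            break

def enqueue(command, data):
    # Same routing rule as enqueue_command: TrialEnd and PERIODICAL metrics
    # are assessor business; everything else goes to the tuner's queue.
    if command == 'TrialEnd' or (command == 'ReportMetricData' and data.get('type') == 'PERIODICAL'):
        assessor_queue.put((command, data))
    else:
        default_queue.put((command, data))

workers = [threading.Thread(target=worker, args=(default_queue, 'default')),
           threading.Thread(target=worker, args=(assessor_queue, 'assessor'))]
for t in workers:
    t.start()

enqueue('ReportMetricData', {'type': 'PERIODICAL', 'value': 0.5})  # -> assessor
enqueue('ReportMetricData', {'type': 'FINAL', 'value': 0.9})       # -> tuner
enqueue('TrialEnd', {'event': 'SUCCEEDED'})                        # -> assessor

stopping = True
for t in workers:
    t.join()
```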
@@ -75,7 +75,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p
 class MultiPhaseMsgDispatcher(MsgDispatcherBase):
     def __init__(self, tuner, assessor=None):
-        super()
+        super(MultiPhaseMsgDispatcher, self).__init__()
         self.tuner = tuner
         self.assessor = assessor
         if assessor is None:
...
@@ -42,11 +42,10 @@ class CommandType(Enum):
     NoMoreTrialJobs = b'NO'
     KillTrialJob = b'KI'

+_lock = threading.Lock()
 try:
     _in_file = open(3, 'rb')
     _out_file = open(4, 'wb')
-    _lock = threading.Lock()
 except OSError:
     _msg = 'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?'
     import logging
@@ -60,8 +59,7 @@ def send(command, data):
     """
     global _lock
     try:
-        if multi_thread_enabled():
-            _lock.acquire()
+        _lock.acquire()
         data = data.encode('utf8')
         assert len(data) < 1000000, 'Command too long'
         msg = b'%b%06d%b' % (command.value, len(data), data)
@@ -69,8 +67,7 @@ def send(command, data):
         _out_file.write(msg)
         _out_file.flush()
     finally:
-        if multi_thread_enabled():
-            _lock.release()
+        _lock.release()

 def receive():
...
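The lock becomes unconditional because, with the new queue workers, two threads may call `send()` concurrently even when multi-thread mode is off, and interleaved writes would corrupt the length-prefixed frame. A small sketch of the framing under a lock, using an in-memory buffer in place of the real pipe on fd 4 (the `b'ME'` command byte is illustrative; the diff only shows `b'NO'` and `b'KI'`):

```python
# Sketch of the wire framing used by send(): 2-byte command code,
# 6-digit zero-padded payload length, then the payload itself.
import io
import threading

_lock = threading.Lock()
_out_file = io.BytesIO()  # stand-in for the real IPC pipe opened on fd 4

def send(command_value, data):
    data = data.encode('utf8')
    assert len(data) < 1000000, 'Command too long'
    msg = b'%b%06d%b' % (command_value, len(data), data)
    with _lock:  # same effect as the acquire()/release() in try/finally above
        _out_file.write(msg)

send(b'ME', '{"value": 0.6}')
print(_out_file.getvalue())  # b'ME000014{"value": 0.6}'
```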
@@ -73,11 +73,12 @@ class AssessorTestCase(TestCase):
         assessor = NaiveAssessor()
         dispatcher = MsgDispatcher(None, assessor)
+        nni.msg_dispatcher_base._worker_fast_exit_on_terminate = False

-        try:
-            dispatcher.run()
-        except Exception as e:
-            self.assertIs(type(e), AssertionError)
-            self.assertEqual(e.args[0], 'Unsupported command: CommandType.NewTrialJob')
+        dispatcher.run()
+        e = dispatcher.worker_exceptions[0]
+        self.assertIs(type(e), AssertionError)
+        self.assertEqual(e.args[0], 'Unsupported command: CommandType.NewTrialJob')

         self.assertEqual(_trials, ['A', 'B', 'A'])
         self.assertEqual(_end_trials, [('A', False), ('B', True)])
@@ -90,4 +91,4 @@ class AssessorTestCase(TestCase):

 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
@@ -88,11 +88,12 @@ class TunerTestCase(TestCase):
         tuner = NaiveTuner()
         dispatcher = MsgDispatcher(tuner)
+        nni.msg_dispatcher_base._worker_fast_exit_on_terminate = False

-        try:
-            dispatcher.run()
-        except Exception as e:
-            self.assertIs(type(e), AssertionError)
-            self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')
+        dispatcher.run()
+        e = dispatcher.worker_exceptions[0]
+        self.assertIs(type(e), AssertionError)
+        self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')

         _reverse_io()  # now we are receiving from Tuner's outgoing stream
         self._assert_params(0, 2, [ ], None)
...
@@ -76,8 +76,8 @@ def run(dispatch_type):
     dipsatcher_list = TUNER_LIST if dispatch_type == 'Tuner' else ASSESSOR_LIST
     for dispatcher_name in dipsatcher_list:
         try:
-            # sleep 5 seconds here, to make sure previous stopped exp has enough time to exit to avoid port conflict
-            time.sleep(5)
+            # Sleep here to make sure previous stopped exp has enough time to exit to avoid port conflict
+            time.sleep(6)
             test_builtin_dispatcher(dispatch_type, dispatcher_name)
             print(GREEN + 'Test %s %s: TEST PASS' % (dispatcher_name, dispatch_type) + CLEAR)
         except Exception as error:
...
@@ -58,6 +58,7 @@ Optional('tuner'): Or({
     Optional('classArgs'): {
         'optimize_mode': Or('maximize', 'minimize')
    },
+    Optional('includeIntermediateResults'): bool,
     Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
 },{
     'builtinTunerName': Or('BatchTuner', 'GridSearch'),
...
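A quick way to see the new key in action is to validate a tuner section with the `schema` package that config_schema.py uses. The tuner names below are illustrative (the diff only shows the 'BatchTuner'/'GridSearch' branch), and `validate` raises `SchemaError` on a non-conforming config:

```python
# Validate a tuner config dict containing the new includeIntermediateResults
# key, mirroring the schema fragment above. Tuner names are placeholders.
from schema import Schema, Optional, Or, And

tuner_schema = Schema({
    'builtinTunerName': Or('TPE', 'Random', 'Anneal', 'Evolution'),
    Optional('classArgs'): {'optimize_mode': Or('maximize', 'minimize')},
    Optional('includeIntermediateResults'): bool,
    Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
})

tuner_schema.validate({
    'builtinTunerName': 'TPE',
    'classArgs': {'optimize_mode': 'maximize'},
    'includeIntermediateResults': True,
})
```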