"ts/webui/src/static/model/trial.ts" did not exist on "ac232be520d1047eda64d1481fe0a42ef0e4580f"
Unverified commit 63697ec5 authored by chicm-ms, committed by GitHub

Route tuner and assessor commands to 2 separate queues (#891)

1. Route tuner and assessor commands to 2 separate queues (issue #841)
2. Allow the tuner to leverage intermediate results when a trial is early stopped (issue #843)
parent c297650a
......@@ -231,12 +231,17 @@ machineList:
* __classArgs__
__classArgs__ specifies the arguments of the tuner algorithm.
* __gpuNum__
__gpuNum__ specifies the number of GPUs used to run the tuner process. The value of this field must be a non-negative number.
Note: users may specify only one way to set the tuner, for example, set {tunerName, optimizationMode} or {tunerCommand, tunerCwd}, but not both.
* __includeIntermediateResults__
If __includeIntermediateResults__ is true, the last intermediate result of a trial that is early stopped by the assessor is sent to the tuner as the final result. The default value of __includeIntermediateResults__ is false. A minimal example config follows this section.
* __assessor__
* Description
......
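For illustration, a minimal sketch of an experiment config using the new option (the tuner and assessor names are just examples drawn from the built-in lists validated elsewhere in this commit):

```yaml
tuner:
  builtinTunerName: TPE
  classArgs:
    optimize_mode: maximize
  # Forward the last intermediate result of an early-stopped trial
  # to the tuner as its final result (defaults to false).
  includeIntermediateResults: true
assessor:
  builtinAssessorName: Medianstop
```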
......@@ -46,6 +46,7 @@ interface ExperimentParams {
classFileName?: string;
checkpointDir: string;
gpuNum?: number;
includeIntermediateResults?: boolean;
};
assessor?: {
className: string;
......
......@@ -277,11 +277,17 @@ class NNIManager implements Manager {
newCwd = cwd;
}
// TODO: add CUDA_VISIBLE_DEVICES
let includeIntermediateResultsEnv: boolean | undefined = false;
if (this.experimentProfile.params.tuner !== undefined) {
includeIntermediateResultsEnv = this.experimentProfile.params.tuner.includeIntermediateResults;
}
let nniEnv = {
NNI_MODE: mode,
NNI_CHECKPOINT_DIRECTORY: dataDirectory,
NNI_LOG_DIRECTORY: getLogDir(),
NNI_LOG_LEVEL: getLogLevel()
NNI_LOG_LEVEL: getLogLevel(),
NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv
};
let newEnv = Object.assign({}, process.env, nniEnv);
const tunerProc: ChildProcess = spawn(command, [], {
......
......@@ -106,7 +106,7 @@ describe('core/ipcInterface.terminate', (): void => {
assert.ok(!procError);
deferred.resolve();
},
2000);
5000);
return deferred.promise;
});
......
......@@ -159,7 +159,8 @@ export namespace ValidationSchemas {
className: joi.string(),
classArgs: joi.any(),
gpuNum: joi.number().min(0),
checkpointDir: joi.string().allow('')
checkpointDir: joi.string().allow(''),
includeIntermediateResults: joi.boolean()
}),
assessor: joi.object({
builtinAssessorName: joi.string().valid('Medianstop', 'Curvefitting'),
......
......@@ -18,14 +18,15 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import os
import logging
from collections import defaultdict
import json_tricks
import threading
from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
from .common import multi_thread_enabled
_logger = logging.getLogger(__name__)
......@@ -70,7 +71,7 @@ def _pack_parameter(parameter_id, params, customized=False):
class MsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super().__init__()
super(MsgDispatcher, self).__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
......@@ -87,9 +88,8 @@ class MsgDispatcher(MsgDispatcherBase):
self.assessor.save_checkpoint()
def handle_initialize(self, data):
'''
data is search space
'''
"""Data is search space
"""
self.tuner.update_search_space(data)
send(CommandType.Initialized, '')
return True
......@@ -126,12 +126,7 @@ class MsgDispatcher(MsgDispatcherBase):
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
if data['type'] == 'FINAL':
id_ = data['parameter_id']
value = data['value']
if id_ in _customized_parameter_ids:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], value)
self._handle_final_metric_data(data)
elif data['type'] == 'PERIODICAL':
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
......@@ -157,7 +152,19 @@ class MsgDispatcher(MsgDispatcherBase):
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
return True
def _handle_final_metric_data(self, data):
"""Call tuner to process final results
"""
id_ = data['parameter_id']
value = data['value']
if id_ in _customized_parameter_ids:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], value)
def _handle_intermediate_metric_data(self, data):
"""Call assessor to process intermediate results
"""
if data['type'] != 'PERIODICAL':
return True
if self.assessor is None:
......@@ -187,5 +194,20 @@ class MsgDispatcher(MsgDispatcherBase):
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
# notify tuner
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS'))
if os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS') == 'true':
self._earlystop_notify_tuner(data)
else:
_logger.debug('GOOD')
def _earlystop_notify_tuner(self, data):
"""Send last intermediate result as final result to tuner in case the
trial is early stopped.
"""
_logger.debug('Early stop notify tuner data: [%s]', data)
data['type'] = 'FINAL'
if multi_thread_enabled():
self._handle_final_metric_data(data)
else:
self.enqueue_command(CommandType.ReportMetricData, data)
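Note the design choice in `_earlystop_notify_tuner`: rather than calling the tuner directly, the relabeled metric is re-enqueued as an ordinary ReportMetricData command (except in multi-thread mode), so it flows through the same path as any final result. From a tuner author's perspective, the early-stop value is indistinguishable from a regular final metric. A minimal sketch of a custom tuner receiving it, assuming only the standard `nni.tuner.Tuner` interface of this era (the class and attribute names are illustrative):

```python
from nni.tuner import Tuner

class RecordingTuner(Tuner):
    """Illustrative tuner: with includeIntermediateResults enabled, an
    early-stopped trial's last intermediate metric arrives here exactly
    like an ordinary final result."""

    def __init__(self):
        self.results = {}

    def update_search_space(self, search_space):
        self.search_space = search_space

    def generate_parameters(self, parameter_id):
        # Placeholder: a real tuner samples from self.search_space.
        return {}

    def receive_trial_result(self, parameter_id, parameters, value):
        # 'value' may be the relabeled last intermediate metric of an
        # early-stopped trial; the tuner cannot tell the difference.
        self.results[parameter_id] = value
```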
......@@ -19,14 +19,13 @@
# ==================================================================================================
#import json_tricks
import logging
import os
from queue import Queue
import sys
import threading
import logging
from multiprocessing.dummy import Pool as ThreadPool
from queue import Queue, Empty
import json_tricks
from .common import init_logger, multi_thread_enabled
from .recoverable import Recoverable
from .protocol import CommandType, receive
......@@ -34,57 +33,109 @@ from .protocol import CommandType, receive
init_logger('dispatcher.log')
_logger = logging.getLogger(__name__)
QUEUE_LEN_WARNING_MARK = 20
_worker_fast_exit_on_terminate = True
class MsgDispatcherBase(Recoverable):
def __init__(self):
if multi_thread_enabled():
self.pool = ThreadPool()
self.thread_results = []
else:
self.stopping = False
self.default_command_queue = Queue()
self.assessor_command_queue = Queue()
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,))
self.default_worker.start()
self.assessor_worker.start()
self.worker_exceptions = []
def run(self):
"""Run the tuner.
This function never returns unless an exception is raised.
"""
_logger.info('Start dispatcher')
mode = os.getenv('NNI_MODE')
if mode == 'resume':
self.load_checkpoint()
while True:
_logger.debug('waiting receive_message')
command, data = receive()
if data:
data = json_tricks.loads(data)
if command is None or command is CommandType.Terminate:
break
if multi_thread_enabled():
result = self.pool.map_async(self.handle_request_thread, [(command, data)])
result = self.pool.map_async(self.process_command_thread, [(command, data)])
self.thread_results.append(result)
if any([thread_result.ready() and not thread_result.successful() for thread_result in self.thread_results]):
_logger.debug('Caught thread exception')
break
else:
self.handle_request((command, data))
self.enqueue_command(command, data)
_logger.info('Dispatcher exiting...')
self.stopping = True
if multi_thread_enabled():
self.pool.close()
self.pool.join()
else:
self.default_worker.join()
self.assessor_worker.join()
_logger.info('Terminated by NNI manager')
def handle_request_thread(self, request):
def command_queue_worker(self, command_queue):
"""Process commands in command queues.
"""
while True:
try:
# set timeout to ensure self.stopping is checked periodically
command, data = command_queue.get(timeout=3)
try:
self.process_command(command, data)
except Exception as e:
_logger.exception(e)
self.worker_exceptions.append(e)
break
except Empty:
pass
if self.stopping and (_worker_fast_exit_on_terminate or command_queue.empty()):
break
def enqueue_command(self, command, data):
"""Enqueue command into command queues
"""
if command == CommandType.TrialEnd or (command == CommandType.ReportMetricData and data['type'] == 'PERIODICAL'):
self.assessor_command_queue.put((command, data))
else:
self.default_command_queue.put((command, data))
qsize = self.default_command_queue.qsize()
if qsize >= QUEUE_LEN_WARNING_MARK:
_logger.warning('default queue length: %d', qsize)
qsize = self.assessor_command_queue.qsize()
if qsize >= QUEUE_LEN_WARNING_MARK:
_logger.warning('assessor queue length: %d', qsize)
def process_command_thread(self, request):
"""Worker thread to process a command.
"""
command, data = request
if multi_thread_enabled():
try:
self.handle_request(request)
self.process_command(command, data)
except Exception as e:
_logger.exception(str(e))
raise
else:
pass
def handle_request(self, request):
command, data = request
_logger.debug('handle request: command: [{}], data: [{}]'.format(command, data))
if data:
data = json_tricks.loads(data)
def process_command(self, command, data):
_logger.debug('process_command: command: [{}], data: [{}]'.format(command, data))
command_handlers = {
# Tuner commands:
......
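The enqueue_command routing above is the heart of the commit: TrialEnd and PERIODICAL metric commands go to the assessor queue, everything else to the default (tuner) queue, so a slow assessor can no longer block the tuner. A self-contained sketch of the same two-queue pattern, independent of NNI's classes (all names here are illustrative):

```python
import threading
import time
from queue import Queue, Empty

def worker(queue, handle, stop):
    """Drain one queue; the timeout lets the stop event be checked
    periodically, mirroring command_queue_worker above."""
    while not stop.is_set():
        try:
            command, data = queue.get(timeout=1)
        except Empty:
            continue
        handle(command, data)

def route(command, data, default_q, assessor_q):
    """Route assessor-bound commands to their own queue, everything
    else to the default queue, as enqueue_command does."""
    if command == 'TrialEnd' or (
            command == 'ReportMetricData' and data.get('type') == 'PERIODICAL'):
        assessor_q.put((command, data))
    else:
        default_q.put((command, data))

if __name__ == '__main__':
    default_q, assessor_q = Queue(), Queue()
    stop = threading.Event()
    handle = lambda command, data: print(command, data)
    workers = [threading.Thread(target=worker, args=(q, handle, stop))
               for q in (default_q, assessor_q)]
    for w in workers:
        w.start()
    route('ReportMetricData', {'type': 'PERIODICAL', 'value': 0.9},
          default_q, assessor_q)
    route('NewTrialJob', {'parameter_id': 0}, default_q, assessor_q)
    time.sleep(0.2)  # give the workers a moment to drain (demo only)
    stop.set()
    for w in workers:
        w.join()
```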
......@@ -75,7 +75,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p
class MultiPhaseMsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super()
super(MultiPhaseMsgDispatcher, self).__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
......
......@@ -42,11 +42,10 @@ class CommandType(Enum):
NoMoreTrialJobs = b'NO'
KillTrialJob = b'KI'
_lock = threading.Lock()
try:
_in_file = open(3, 'rb')
_out_file = open(4, 'wb')
_lock = threading.Lock()
except OSError:
_msg = 'IPC pipeline does not exist; maybe you are importing tuner/assessor from trial code?'
import logging
......@@ -60,7 +59,6 @@ def send(command, data):
"""
global _lock
try:
if multi_thread_enabled():
_lock.acquire()
data = data.encode('utf8')
assert len(data) < 1000000, 'Command too long'
......@@ -69,7 +67,6 @@ def send(command, data):
_out_file.write(msg)
_out_file.flush()
finally:
if multi_thread_enabled():
_lock.release()
......
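The protocol change above is a consequence of the new queue workers: `_lock` is created unconditionally and `send()` always acquires it, instead of only when multi_thread_enabled(), because even in single-thread mode two worker threads may now write to the IPC pipe concurrently. A minimal sketch of the invariant (the file handle and frame layout are placeholders, not NNI's wire format):

```python
import threading

_lock = threading.Lock()

def send(out_file, command, payload):
    """Writers must be serialized: the whole frame has to reach the
    pipe atomically even when several threads send concurrently."""
    data = payload.encode('utf8')
    assert len(data) < 1000000, 'Command too long'
    with _lock:
        out_file.write(command + data)  # placeholder framing
        out_file.flush()
```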
......@@ -73,9 +73,10 @@ class AssessorTestCase(TestCase):
assessor = NaiveAssessor()
dispatcher = MsgDispatcher(None, assessor)
try:
nni.msg_dispatcher_base._worker_fast_exit_on_terminate = False
dispatcher.run()
except Exception as e:
e = dispatcher.worker_exceptions[0]
self.assertIs(type(e), AssertionError)
self.assertEqual(e.args[0], 'Unsupported command: CommandType.NewTrialJob')
......
......@@ -88,9 +88,10 @@ class TunerTestCase(TestCase):
tuner = NaiveTuner()
dispatcher = MsgDispatcher(tuner)
try:
nni.msg_dispatcher_base._worker_fast_exit_on_terminate = False
dispatcher.run()
except Exception as e:
e = dispatcher.worker_exceptions[0]
self.assertIs(type(e), AssertionError)
self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')
......
......@@ -76,8 +76,8 @@ def run(dispatch_type):
dispatcher_list = TUNER_LIST if dispatch_type == 'Tuner' else ASSESSOR_LIST
for dispatcher_name in dispatcher_list:
try:
# sleep 5 seconds here, to make sure previous stopped exp has enough time to exit to avoid port conflict
time.sleep(5)
# Sleep here to make sure the previously stopped experiment has enough time to exit, to avoid port conflicts
time.sleep(6)
test_builtin_dispatcher(dispatch_type, dispatcher_name)
print(GREEN + 'Test %s %s: TEST PASS' % (dispatcher_name, dispatch_type) + CLEAR)
except Exception as error:
......
......@@ -58,6 +58,7 @@ Optional('tuner'): Or({
Optional('classArgs'): {
'optimize_mode': Or('maximize', 'minimize')
},
Optional('includeIntermediateResults'): bool,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'builtinTunerName': Or('BatchTuner', 'GridSearch'),
......