"vscode:/vscode.git/clone" did not exist on "cf95cfc07cbc302d38ea6409eab8146f10cf200b"
Unverified commit a5d614de authored by chicm-ms, committed by GitHub

Asynchronous dispatcher (#372)

* Asynchronous dispatcher

* updates

* updates

* updates

* updates
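This commit adds an optional `multiThread` mode that is threaded all the way from the experiment config (nnictl YAML and the REST schema) through `getMsgDispatcherCommand` to a new `--multi_thread` flag on the Python dispatcher. When enabled, the dispatcher handles each incoming command on a thread pool instead of synchronously, and the IPC `send` path is guarded by a lock so concurrent handlers cannot interleave messages. The commit also introduces an explicit INITIALIZE/INITIALIZED handshake: the manager now sends the search space via INITIALIZE and waits for the tuner's INITIALIZED reply before requesting any trial jobs, replacing the old pattern of firing UPDATE_SEARCH_SPACE and REQUEST_TRIAL_JOBS back to back.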
@@ -34,6 +34,7 @@ interface ExperimentParams {
     searchSpace: string;
     trainingServicePlatform: string;
     multiPhase?: boolean;
+    multiThread?: boolean;
     tuner: {
         className: string;
         builtinTunerName?: string;
...
@@ -158,12 +158,16 @@ function parseArg(names: string[]): string {
  * @param assessor: similiar as tuner
  *
  */
-function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean = false): string {
+function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean = false, multiThread: boolean = false): string {
     let command: string = `python3 -m nni --tuner_class_name ${tuner.className}`;
     if (multiPhase) {
         command += ' --multi_phase';
     }
+    if (multiThread) {
+        command += ' --multi_thread';
+    }
     if (tuner.classArgs !== undefined) {
         command += ` --tuner_args ${JSON.stringify(JSON.stringify(tuner.classArgs))}`;
     }
...
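To make the effect concrete, here is a small Python sketch (editor's illustration, not code from this commit; the tuner dict and the `TPE` class name are hypothetical) that mirrors the TypeScript string building above:

```python
# Illustrative only: mirrors the TypeScript getMsgDispatcherCommand logic.
import json

def build_dispatcher_command(tuner, multi_phase=False, multi_thread=False):
    command = 'python3 -m nni --tuner_class_name {}'.format(tuner['className'])
    if multi_phase:
        command += ' --multi_phase'
    if multi_thread:
        command += ' --multi_thread'
    if tuner.get('classArgs') is not None:
        # Double-encode so the quoted JSON survives shell argument parsing,
        # matching the nested JSON.stringify in the TypeScript version.
        command += ' --tuner_args {}'.format(json.dumps(json.dumps(tuner['classArgs'])))
    return command

print(build_dispatcher_command({'className': 'TPE', 'classArgs': {'optimize_mode': 'maximize'}},
                               multi_thread=True))
# python3 -m nni --tuner_class_name TPE --multi_thread --tuner_args "{\"optimize_mode\": \"maximize\"}"
```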
@@ -26,6 +26,7 @@ const ADD_CUSTOMIZED_TRIAL_JOB = 'AD';
 const TRIAL_END = 'EN';
 const TERMINATE = 'TE';
+const INITIALIZED = 'ID';
 const NEW_TRIAL_JOB = 'TR';
 const SEND_TRIAL_JOB_PARAMETER = 'SP';
 const NO_MORE_TRIAL_JOBS = 'NO';
@@ -39,6 +40,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
     ADD_CUSTOMIZED_TRIAL_JOB,
     TERMINATE,
+    INITIALIZED,
     NEW_TRIAL_JOB,
     SEND_TRIAL_JOB_PARAMETER,
     NO_MORE_TRIAL_JOBS
@@ -61,6 +63,7 @@ export {
     ADD_CUSTOMIZED_TRIAL_JOB,
     TRIAL_END,
     TERMINATE,
+    INITIALIZED,
     NEW_TRIAL_JOB,
     NO_MORE_TRIAL_JOBS,
     KILL_TRIAL_JOB,
...
@@ -37,8 +37,8 @@ import {
 } from '../common/trainingService';
 import { delay, getLogDir, getMsgDispatcherCommand } from '../common/utils';
 import {
-    ADD_CUSTOMIZED_TRIAL_JOB, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, REPORT_METRIC_DATA,
-    REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
+    ADD_CUSTOMIZED_TRIAL_JOB, INITIALIZE, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS,
+    REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
 } from './commands';
 import { createDispatcherInterface, IpcInterface } from './ipcInterface';
@@ -127,7 +127,8 @@ class NNIManager implements Manager {
             this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
         }
-        const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
+        const dispatcherCommand: string = getMsgDispatcherCommand(
+            expParams.tuner, expParams.assessor, expParams.multiPhase, expParams.multiThread);
         this.log.debug(`dispatcher command: ${dispatcherCommand}`);
         this.setupTuner(
             //expParams.tuner.tunerCommand,
@@ -159,7 +160,8 @@ class NNIManager implements Manager {
             this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
         }
-        const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
+        const dispatcherCommand: string = getMsgDispatcherCommand(
+            expParams.tuner, expParams.assessor, expParams.multiPhase, expParams.multiThread);
         this.log.debug(`dispatcher command: ${dispatcherCommand}`);
         this.setupTuner(
             dispatcherCommand,
@@ -419,16 +421,20 @@ class NNIManager implements Manager {
         } else {
             this.trialConcurrencyChange = requestTrialNum;
         }
-        for (let i: number = 0; i < requestTrialNum; i++) {
+        const requestCustomTrialNum: number = Math.min(requestTrialNum, this.customizedTrials.length);
+        for (let i: number = 0; i < requestCustomTrialNum; i++) {
             // ask tuner for more trials
             if (this.customizedTrials.length > 0) {
                 const hyperParams: string | undefined = this.customizedTrials.shift();
                 this.dispatcher.sendCommand(ADD_CUSTOMIZED_TRIAL_JOB, hyperParams);
-            } else {
-                this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, '1');
             }
         }
+        if (requestTrialNum - requestCustomTrialNum > 0) {
+            this.requestTrialJobs(requestTrialNum - requestCustomTrialNum);
+        }
         // check maxtrialnum and maxduration here
         if (this.experimentProfile.execDuration > this.experimentProfile.params.maxExecDuration ||
             this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) {
@@ -526,11 +532,9 @@ class NNIManager implements Manager {
         if (this.dispatcher === undefined) {
             throw new Error('Dispatcher error: tuner has not been setup');
         }
-        // TO DO: we should send INITIALIZE command to tuner if user's tuner needs to run init method in tuner
-        this.log.debug(`Send tuner command: update search space: ${this.experimentProfile.params.searchSpace}`);
-        this.dispatcher.sendCommand(UPDATE_SEARCH_SPACE, this.experimentProfile.params.searchSpace);
-        this.log.debug(`Send tuner command: ${this.experimentProfile.params.trialConcurrency}`);
-        this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, String(this.experimentProfile.params.trialConcurrency));
+        this.log.debug(`Send tuner command: INITIALIZE: ${this.experimentProfile.params.searchSpace}`);
+        // Tuner need to be initialized with search space before generating any hyper parameters
+        this.dispatcher.sendCommand(INITIALIZE, this.experimentProfile.params.searchSpace);
     }

     private async onTrialJobMetrics(metric: TrialJobMetric): Promise<void> {
@@ -541,9 +545,32 @@ class NNIManager implements Manager {
         this.dispatcher.sendCommand(REPORT_METRIC_DATA, metric.data);
     }

+    private requestTrialJobs(jobNum: number): void {
+        if (jobNum < 1) {
+            return;
+        }
+        if (this.dispatcher === undefined) {
+            throw new Error('Dispatcher error: tuner has not been setup');
+        }
+        if (this.experimentProfile.params.multiThread) {
+            // Send multiple requests to ensure multiple hyper parameters are generated in non-blocking way.
+            // For a single REQUEST_TRIAL_JOBS request, hyper parameters are generated one by one sequentially.
+            for (let i: number = 0; i < jobNum; i++) {
+                this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, '1');
+            }
+        } else {
+            this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, String(jobNum));
+        }
+    }
+
     private async onTunerCommand(commandType: string, content: string): Promise<void> {
         this.log.info(`Command from tuner: ${commandType}, ${content}`);
         switch (commandType) {
+            case INITIALIZED:
+                // Tuner is intialized, search space is set, request tuner to generate hyper parameters
+                this.requestTrialJobs(this.experimentProfile.params.trialConcurrency);
+                break;
             case NEW_TRIAL_JOB:
                 this.waitingTrials.push(content);
                 break;
...
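Taken together with the command constants above, startup becomes a proper handshake: INITIALIZE carries the search space down, INITIALIZED comes back once the tuner is ready, and only then are hyper parameters requested. A minimal sketch of that flow (editor's illustration over in-memory queues, not NNI code; the command names match the constants, the plumbing is invented):

```python
# Editor's sketch: the startup handshake added in this commit, simulated locally.
import queue
import threading

to_tuner, to_manager = queue.Queue(), queue.Queue()

def tuner_side():
    command, data = to_tuner.get()
    assert command == 'INITIALIZE'       # payload is the search space
    # ... tuner.update_search_space(data) would run here ...
    to_manager.put(('INITIALIZED', ''))  # tell the manager the tuner is ready

def manager_side(search_space, trial_concurrency):
    to_tuner.put(('INITIALIZE', search_space))
    command, _ = to_manager.get()
    if command == 'INITIALIZED':
        # Only now is it safe to request hyper parameters. In multiThread mode the
        # manager sends trial_concurrency single-job requests so they can be
        # generated concurrently; otherwise one request covers the whole batch.
        for _ in range(trial_concurrency):
            to_tuner.put(('REQUEST_TRIAL_JOBS', '1'))

worker = threading.Thread(target=tuner_side)
worker.start()
manager_side('{"lr": {"_type": "uniform", "_value": [0.0001, 0.1]}}', 4)
worker.join()
```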
@@ -68,6 +68,7 @@ export namespace ValidationSchemas {
         searchSpace: joi.string().required(),
         maxExecDuration: joi.number().min(0).required(),
         multiPhase: joi.boolean(),
+        multiThread: joi.boolean(),
         tuner: joi.object({
             builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch'),
             codeDir: joi.string(),
...
@@ -28,6 +28,7 @@ import json
 import importlib

 from .constants import ModuleName, ClassName, ClassArgs
+from nni.common import enable_multi_thread
 from nni.msg_dispatcher import MsgDispatcher
 from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher

 logger = logging.getLogger('nni.main')
@@ -91,6 +92,7 @@ def parse_args():
     parser.add_argument('--assessor_class_filename', type=str, required=False,
                         help='Assessor class file path')
     parser.add_argument('--multi_phase', action='store_true')
+    parser.add_argument('--multi_thread', action='store_true')

     flags, _ = parser.parse_known_args()
     return flags
@@ -101,6 +103,8 @@ def main():
     '''
     args = parse_args()
+    if args.multi_thread:
+        enable_multi_thread()
     tuner = None
     assessor = None
...
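A note on the flag plumbing: because the dispatcher entry point uses `parse_known_args`, the new switch coexists with arguments it does not itself declare. A standalone sketch of just that argparse behavior (editor's illustration):

```python
# Editor's sketch of how the new flag is consumed (argparse behavior only).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--multi_phase', action='store_true')
parser.add_argument('--multi_thread', action='store_true')

# parse_known_args ignores flags this parser does not know about.
flags, unknown = parser.parse_known_args(['--multi_thread', '--tuner_class_name', 'TPE'])
print(flags.multi_thread, unknown)   # True ['--tuner_class_name', 'TPE']
```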
@@ -78,3 +78,12 @@ def init_logger(logger_file_path):
     logging.getLogger('matplotlib').setLevel(logging.INFO)
     sys.stdout = _LoggerFileWrapper(logger_file)
+
+_multi_thread = False
+
+def enable_multi_thread():
+    global _multi_thread
+    _multi_thread = True
+
+def multi_thread_enabled():
+    return _multi_thread
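The flag is a module-level global rather than per-object state, so enabling it once in `main()`, before any dispatcher or protocol code runs, switches every later `multi_thread_enabled()` check in the process. A usage sketch (assuming the `nni` package from this commit is importable):

```python
# Editor's sketch: the flag is process-global; enable it once, everyone sees it.
from nni.common import enable_multi_thread, multi_thread_enabled

assert not multi_thread_enabled()   # default is the synchronous, single-threaded mode
enable_multi_thread()
assert multi_thread_enabled()       # all subsequent callers observe the change
```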
@@ -21,6 +21,7 @@
 import logging
 from collections import defaultdict
 import json_tricks
+import threading

 from .protocol import CommandType, send
 from .msg_dispatcher_base import MsgDispatcherBase
@@ -69,7 +70,7 @@ def _pack_parameter(parameter_id, params, customized=False):
 class MsgDispatcher(MsgDispatcherBase):
     def __init__(self, tuner, assessor=None):
-        super()
+        super().__init__()
         self.tuner = tuner
         self.assessor = assessor
         if assessor is None:
@@ -85,6 +86,14 @@ class MsgDispatcher(MsgDispatcherBase):
         if self.assessor is not None:
             self.assessor.save_checkpoint()

+    def handle_initialize(self, data):
+        '''
+        data is search space
+        '''
+        self.tuner.update_search_space(data)
+        send(CommandType.Initialized, '')
+        return True
+
     def handle_request_trial_jobs(self, data):
         # data: number or trial jobs
         ids = [_create_parameter_id() for _ in range(data)]
@@ -127,7 +136,7 @@ class MsgDispatcher(MsgDispatcherBase):
             if self.assessor is not None:
                 self._handle_intermediate_metric_data(data)
             else:
-            pass
+                pass
         else:
             raise ValueError('Data type not supported: {}'.format(data['type']))
...
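One fix in this file is easy to miss: bare `super()` only creates a proxy object and never calls the parent initializer, so `MsgDispatcherBase.__init__` (and with it the new thread-pool setup) would silently be skipped. A standalone illustration with toy classes (editor's illustration, not NNI code):

```python
# Editor's illustration of the super() fix.
class Base:
    def __init__(self):
        self.pool = 'ready'

class Broken(Base):
    def __init__(self):
        super()              # creates a proxy but calls nothing: Base.__init__ never runs

class Fixed(Base):
    def __init__(self):
        super().__init__()   # actually runs Base.__init__

assert not hasattr(Broken(), 'pool')
assert Fixed().pool == 'ready'
```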
@@ -22,8 +22,8 @@
 import os
 import logging
 import json_tricks
+from multiprocessing.dummy import Pool as ThreadPool

-from .common import init_logger
+from .common import init_logger, multi_thread_enabled
 from .recoverable import Recoverable
 from .protocol import CommandType, receive
@@ -31,6 +31,10 @@ init_logger('dispatcher.log')
 _logger = logging.getLogger(__name__)

 class MsgDispatcherBase(Recoverable):
+    def __init__(self):
+        if multi_thread_enabled():
+            self.pool = ThreadPool()
+
     def run(self):
         """Run the tuner.
         This function will never return unless raise.
@@ -39,17 +43,24 @@ class MsgDispatcherBase(Recoverable):
         if mode == 'resume':
             self.load_checkpoint()

-        while self.handle_request():
-            pass
+        while True:
+            _logger.debug('waiting receive_message')
+            command, data = receive()
+            if command is None:
+                break
+            if multi_thread_enabled():
+                self.pool.map_async(self.handle_request, [(command, data)])
+            else:
+                self.handle_request((command, data))
+
+        if multi_thread_enabled():
+            self.pool.close()
+            self.pool.join()

         _logger.info('Terminated by NNI manager')

-    def handle_request(self):
-        _logger.debug('waiting receive_message')
-        command, data = receive()
-        if command is None:
-            return False
+    def handle_request(self, request):
+        command, data = request

         _logger.debug('handle request: command: [{}], data: [{}]'.format(command, data))
@@ -60,6 +71,7 @@ class MsgDispatcherBase(Recoverable):
         command_handlers = {
             # Tunner commands:
+            CommandType.Initialize: self.handle_initialize,
             CommandType.RequestTrialJobs: self.handle_request_trial_jobs,
             CommandType.UpdateSearchSpace: self.handle_update_search_space,
             CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial,
@@ -74,6 +86,9 @@ class MsgDispatcherBase(Recoverable):
         return command_handlers[command](data)

+    def handle_initialize(self, data):
+        raise NotImplementedError('handle_initialize not implemented')
+
     def handle_request_trial_jobs(self, data):
         raise NotImplementedError('handle_request_trial_jobs not implemented')
...
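`multiprocessing.dummy.Pool` is a thread pool exposing the `multiprocessing.Pool` API, so `map_async` returns immediately and the receive loop is never blocked by a slow handler. A self-contained sketch of the same receive/dispatch pattern (editor's illustration; `receive` is faked with a list):

```python
# Editor's sketch of the new run loop: blocking receive on one thread,
# handlers fanned out to a thread pool. The message source is faked here.
from multiprocessing.dummy import Pool as ThreadPool  # threads, not processes
import time

_messages = [('RequestTrialJobs', 1), ('RequestTrialJobs', 1), (None, None)]

def receive():
    return _messages.pop(0)    # stand-in for the blocking pipe read

def handle_request(request):
    command, data = request
    time.sleep(0.1)            # a slow tuner.generate_parameters() would block here
    print('handled', command, data)

pool = ThreadPool()
while True:
    command, data = receive()
    if command is None:
        break
    # map_async returns immediately, so a slow handler no longer stalls the
    # receive loop; in the old synchronous loop, the second request would have
    # had to wait for the first to finish.
    pool.map_async(handle_request, [(command, data)])
pool.close()
pool.join()
```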
@@ -91,6 +91,14 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
         if self.assessor is not None:
             self.assessor.save_checkpoint()

+    def handle_initialize(self, data):
+        '''
+        data is search space
+        '''
+        self.tuner.update_search_space(data)
+        send(CommandType.Initialized, '')
+        return True
+
     def handle_request_trial_jobs(self, data):
         # data: number or trial jobs
         ids = [_create_parameter_id() for _ in range(data)]
...
@@ -19,7 +19,9 @@
 # ==================================================================================================

 import logging
+import threading
 from enum import Enum
+from .common import multi_thread_enabled

 class CommandType(Enum):
@@ -33,6 +35,7 @@ class CommandType(Enum):
     Terminate = b'TE'

     # out
+    Initialized = b'ID'
     NewTrialJob = b'TR'
     SendTrialJobParameter = b'SP'
     NoMoreTrialJobs = b'NO'
@@ -42,6 +45,7 @@
 try:
     _in_file = open(3, 'rb')
     _out_file = open(4, 'wb')
+    _lock = threading.Lock()
 except OSError:
     _msg = 'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?'
     import logging
@@ -53,12 +57,19 @@ def send(command, data):
     command: CommandType object.
     data: string payload.
     """
-    data = data.encode('utf8')
-    assert len(data) < 1000000, 'Command too long'
-    msg = b'%b%06d%b' % (command.value, len(data), data)
-    logging.getLogger(__name__).debug('Sending command, data: [%s]' % msg)
-    _out_file.write(msg)
-    _out_file.flush()
+    global _lock
+    try:
+        if multi_thread_enabled():
+            _lock.acquire()
+        data = data.encode('utf8')
+        assert len(data) < 1000000, 'Command too long'
+        msg = b'%b%06d%b' % (command.value, len(data), data)
+        logging.getLogger(__name__).debug('Sending command, data: [%s]' % msg)
+        _out_file.write(msg)
+        _out_file.flush()
+    finally:
+        if multi_thread_enabled():
+            _lock.release()

 def receive():
...
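Why the lock: each message is a length-prefixed frame (two command bytes, a six-digit zero-padded payload length, then the payload), and with handlers running on pool threads, two `send` calls could otherwise interleave their write/flush pairs on the shared pipe. A standalone sketch of the framing and the lock (editor's illustration; `BytesIO` stands in for file descriptor 4, and the lock is taken unconditionally here, whereas the commit takes it only when multi-thread mode is on):

```python
# Editor's sketch: the wire format plus the lock that keeps concurrent
# senders from interleaving frames on the shared pipe.
import io
import threading

_out_file = io.BytesIO()
_lock = threading.Lock()

def send(command_value, data):
    data = data.encode('utf8')
    assert len(data) < 1000000, 'Command too long'
    # Frame layout: 2 command bytes + 6-digit zero-padded length + payload.
    msg = b'%b%06d%b' % (command_value, len(data), data)
    with _lock:              # without this, two threads could interleave
        _out_file.write(msg) # their write/flush pairs

send(b'ID', '')
send(b'TR', '{"parameter_id": 0}')
print(_out_file.getvalue())
# b'ID000000TR000019{"parameter_id": 0}'
```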
@@ -31,6 +31,7 @@ Optional('maxTrialNum'): And(int, lambda x: 1 <= x <= 99999),
 'trainingServicePlatform': And(str, lambda x: x in ['remote', 'local', 'pai', 'kubeflow']),
 Optional('searchSpacePath'): os.path.exists,
 Optional('multiPhase'): bool,
+Optional('multiThread'): bool,
 'useAnnotation': bool,
 'tuner': Or({
     'builtinTunerName': Or('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch'),
...
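Together with the nnictl change below, users opt in from the experiment YAML with a single top-level `multiThread: true` line; the key is optional, and leaving it out keeps the old synchronous dispatcher behavior.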
@@ -196,6 +196,8 @@ def set_experiment(experiment_config, mode, port, config_file_name):
         request_data['description'] = experiment_config['description']
     if experiment_config.get('multiPhase'):
         request_data['multiPhase'] = experiment_config.get('multiPhase')
+    if experiment_config.get('multiThread'):
+        request_data['multiThread'] = experiment_config.get('multiThread')
     request_data['tuner'] = experiment_config['tuner']
     if 'assessor' in experiment_config:
         request_data['assessor'] = experiment_config['assessor']
...