Commit 9d01d083 authored by chicm-ms's avatar chicm-ms Committed by Yan Ni
Browse files

Refactor dispatcher cmdline (#1862)

parent 2a81b08c
...@@ -27,7 +27,7 @@ interface ExperimentParams { ...@@ -27,7 +27,7 @@ interface ExperimentParams {
versionCheck?: boolean; versionCheck?: boolean;
logCollection?: string; logCollection?: string;
tuner?: { tuner?: {
className: string; className?: string;
builtinTunerName?: string; builtinTunerName?: string;
codeDir?: string; codeDir?: string;
classArgs?: any; classArgs?: any;
...@@ -37,7 +37,7 @@ interface ExperimentParams { ...@@ -37,7 +37,7 @@ interface ExperimentParams {
gpuIndices?: string; gpuIndices?: string;
}; };
assessor?: { assessor?: {
className: string; className?: string;
builtinAssessorName?: string; builtinAssessorName?: string;
codeDir?: string; codeDir?: string;
classArgs?: any; classArgs?: any;
...@@ -45,7 +45,7 @@ interface ExperimentParams { ...@@ -45,7 +45,7 @@ interface ExperimentParams {
checkpointDir: string; checkpointDir: string;
}; };
advisor?: { advisor?: {
className: string; className?: string;
builtinAdvisorName?: string; builtinAdvisorName?: string;
codeDir?: string; codeDir?: string;
classArgs?: any; classArgs?: any;
......
...@@ -17,7 +17,7 @@ import * as util from 'util'; ...@@ -17,7 +17,7 @@ import * as util from 'util';
import { Database, DataStore } from './datastore'; import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo'; import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo';
import { Manager } from './manager'; import { ExperimentParams, Manager } from './manager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService'; import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
function getExperimentRootDir(): string { function getExperimentRootDir(): string {
...@@ -130,15 +130,6 @@ function parseArg(names: string[]): string { ...@@ -130,15 +130,6 @@ function parseArg(names: string[]): string {
return ''; return '';
} }
function encodeCmdLineArgs(args: any): any {
if(process.platform === 'win32'){
return JSON.stringify(args);
}
else{
return JSON.stringify(JSON.stringify(args));
}
}
function getCmdPy(): string { function getCmdPy(): string {
let cmd = 'python3'; let cmd = 'python3';
if(process.platform === 'win32'){ if(process.platform === 'win32'){
...@@ -150,83 +141,14 @@ function getCmdPy(): string { ...@@ -150,83 +141,14 @@ function getCmdPy(): string {
/** /**
* Generate command line to start automl algorithm(s), * Generate command line to start automl algorithm(s),
* either start advisor or start a process which runs tuner and assessor * either start advisor or start a process which runs tuner and assessor
* @param tuner : For builtin tuner:
* {
* className: 'EvolutionTuner'
* classArgs: {
* optimize_mode: 'maximize',
* population_size: 3
* }
* }
* customized:
* {
* codeDir: '/tmp/mytuner'
* classFile: 'best_tuner.py'
* className: 'BestTuner'
* classArgs: {
* optimize_mode: 'maximize',
* population_size: 3
* }
* }
* *
* @param assessor: similiar as tuner * @param expParams: experiment startup parameters
* @param advisor: similar as tuner
* *
*/ */
function getMsgDispatcherCommand(tuner: any, assessor: any, advisor: any, multiPhase: boolean = false, multiThread: boolean = false): string { function getMsgDispatcherCommand(expParams: ExperimentParams): string {
if ((tuner || assessor) && advisor) { const clonedParams = Object.assign({}, expParams);
throw new Error('Error: specify both tuner/assessor and advisor is not allowed'); delete clonedParams.searchSpace;
} return `${getCmdPy()} -m nni --exp_params ${Buffer.from(JSON.stringify(clonedParams)).toString('base64')}`;
if (!tuner && !advisor) {
throw new Error('Error: specify neither tuner nor advisor is not allowed');
}
let command: string = `${getCmdPy()} -m nni`;
if (multiPhase) {
command += ' --multi_phase';
}
if (multiThread) {
command += ' --multi_thread';
}
if (advisor) {
command += ` --advisor_class_name ${advisor.className}`;
if (advisor.classArgs !== undefined) {
command += ` --advisor_args ${encodeCmdLineArgs(advisor.classArgs)}`;
}
if (advisor.codeDir !== undefined && advisor.codeDir.length > 1) {
command += ` --advisor_directory ${advisor.codeDir}`;
}
if (advisor.classFileName !== undefined && advisor.classFileName.length > 1) {
command += ` --advisor_class_filename ${advisor.classFileName}`;
}
} else {
command += ` --tuner_class_name ${tuner.className}`;
if (tuner.classArgs !== undefined) {
command += ` --tuner_args ${encodeCmdLineArgs(tuner.classArgs)}`;
}
if (tuner.codeDir !== undefined && tuner.codeDir.length > 1) {
command += ` --tuner_directory ${tuner.codeDir}`;
}
if (tuner.classFileName !== undefined && tuner.classFileName.length > 1) {
command += ` --tuner_class_filename ${tuner.classFileName}`;
}
if (assessor !== undefined && assessor.className !== undefined) {
command += ` --assessor_class_name ${assessor.className}`;
if (assessor.classArgs !== undefined) {
command += ` --assessor_args ${encodeCmdLineArgs(assessor.classArgs)}`;
}
if (assessor.codeDir !== undefined && assessor.codeDir.length > 1) {
command += ` --assessor_directory ${assessor.codeDir}`;
}
if (assessor.classFileName !== undefined && assessor.classFileName.length > 1) {
command += ` --assessor_class_filename ${assessor.classFileName}`;
}
}
}
return command;
} }
/** /**
......
...@@ -170,8 +170,7 @@ class NNIManager implements Manager { ...@@ -170,8 +170,7 @@ class NNIManager implements Manager {
this.trainingService.setClusterMetadata('log_collection', expParams.logCollection.toString()); this.trainingService.setClusterMetadata('log_collection', expParams.logCollection.toString());
} }
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.advisor, const dispatcherCommand: string = getMsgDispatcherCommand(expParams);
expParams.multiPhase, expParams.multiThread);
this.log.debug(`dispatcher command: ${dispatcherCommand}`); this.log.debug(`dispatcher command: ${dispatcherCommand}`);
const checkpointDir: string = await this.createCheckpointDir(); const checkpointDir: string = await this.createCheckpointDir();
this.setupTuner( this.setupTuner(
...@@ -211,8 +210,7 @@ class NNIManager implements Manager { ...@@ -211,8 +210,7 @@ class NNIManager implements Manager {
this.trainingService.setClusterMetadata('version_check', expParams.versionCheck.toString()); this.trainingService.setClusterMetadata('version_check', expParams.versionCheck.toString());
} }
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.advisor, const dispatcherCommand: string = getMsgDispatcherCommand(expParams);
expParams.multiPhase, expParams.multiThread);
this.log.debug(`dispatcher command: ${dispatcherCommand}`); this.log.debug(`dispatcher command: ${dispatcherCommand}`);
const checkpointDir: string = await this.createCheckpointDir(); const checkpointDir: string = await this.createCheckpointDir();
this.setupTuner( this.setupTuner(
......
...@@ -21,18 +21,26 @@ function startProcess(): void { ...@@ -21,18 +21,26 @@ function startProcess(): void {
const dispatcherCmd: string = getMsgDispatcherCommand( const dispatcherCmd: string = getMsgDispatcherCommand(
// Mock tuner config // Mock tuner config
{ {
className: 'DummyTuner', experimentName: 'exp1',
codeDir: './', maxExecDuration: 3600,
classFileName: 'dummy_tuner.py' searchSpace: '',
}, trainingServicePlatform: 'local',
// Mock assessor config authorName: '',
{ trialConcurrency: 1,
className: 'DummyAssessor', maxTrialNum: 5,
codeDir: './', tuner: {
classFileName: 'dummy_assessor.py' className: 'DummyTuner',
}, codeDir: './',
// advisor classFileName: 'dummy_tuner.py',
undefined checkpointDir: './'
},
assessor: {
className: 'DummyAssessor',
codeDir: './',
classFileName: 'dummy_assessor.py',
checkpointDir: './'
}
}
); );
const proc: ChildProcess = getTunerProc(dispatcherCmd, stdio, 'core/test', process.env); const proc: ChildProcess = getTunerProc(dispatcherCmd, stdio, 'core/test', process.env);
proc.on('error', (error: Error): void => { proc.on('error', (error: Error): void => {
......
...@@ -42,9 +42,9 @@ describe('Unit test for nnimanager', function () { ...@@ -42,9 +42,9 @@ describe('Unit test for nnimanager', function () {
maxExecDuration: 5, maxExecDuration: 5,
maxTrialNum: 3, maxTrialNum: 3,
trainingServicePlatform: 'local', trainingServicePlatform: 'local',
searchSpace: '{"x":1}', searchSpace: '{"lr": {"_type": "choice", "_value": [0.01,0.001]}}',
tuner: { tuner: {
className: 'TPE', builtinTunerName: 'TPE',
classArgs: { classArgs: {
optimize_mode: 'maximize' optimize_mode: 'maximize'
}, },
...@@ -52,7 +52,7 @@ describe('Unit test for nnimanager', function () { ...@@ -52,7 +52,7 @@ describe('Unit test for nnimanager', function () {
gpuNum: 0 gpuNum: 0
}, },
assessor: { assessor: {
className: 'Medianstop', builtinAssessorName: 'Medianstop',
checkpointDir: '', checkpointDir: '',
gpuNum: 1 gpuNum: 1
} }
...@@ -65,9 +65,9 @@ describe('Unit test for nnimanager', function () { ...@@ -65,9 +65,9 @@ describe('Unit test for nnimanager', function () {
maxExecDuration: 6, maxExecDuration: 6,
maxTrialNum: 2, maxTrialNum: 2,
trainingServicePlatform: 'local', trainingServicePlatform: 'local',
searchSpace: '{"y":2}', searchSpace: '{"lr": {"_type": "choice", "_value": [0.01,0.001]}}',
tuner: { tuner: {
className: 'TPE', builtinTunerName: 'TPE',
classArgs: { classArgs: {
optimize_mode: 'maximize' optimize_mode: 'maximize'
}, },
...@@ -75,7 +75,7 @@ describe('Unit test for nnimanager', function () { ...@@ -75,7 +75,7 @@ describe('Unit test for nnimanager', function () {
gpuNum: 0 gpuNum: 0
}, },
assessor: { assessor: {
className: 'Medianstop', builtinAssessorName: 'Medianstop',
checkpointDir: '', checkpointDir: '',
gpuNum: 1 gpuNum: 1
} }
...@@ -198,7 +198,7 @@ describe('Unit test for nnimanager', function () { ...@@ -198,7 +198,7 @@ describe('Unit test for nnimanager', function () {
it('test updateExperimentProfile SEARCH_SPACE', () => { it('test updateExperimentProfile SEARCH_SPACE', () => {
return nniManager.updateExperimentProfile(experimentProfile, 'SEARCH_SPACE').then(() => { return nniManager.updateExperimentProfile(experimentProfile, 'SEARCH_SPACE').then(() => {
nniManager.getExperimentProfile().then((updateProfile) => { nniManager.getExperimentProfile().then((updateProfile) => {
expect(updateProfile.params.searchSpace).to.be.equal('{"y":2}'); expect(updateProfile.params.searchSpace).to.be.equal('{"lr": {"_type": "choice", "_value": [0.01,0.001]}}');
}); });
}).catch((error) => { }).catch((error) => {
assert.fail(error); assert.fail(error);
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
'''
__main__.py
'''
import os import os
import sys import sys
import argparse import argparse
import logging import logging
import json import json
import importlib import importlib
import base64
from .common import enable_multi_thread, enable_multi_phase from .common import enable_multi_thread, enable_multi_phase
from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName
...@@ -29,99 +27,67 @@ def augment_classargs(input_class_args, classname): ...@@ -29,99 +27,67 @@ def augment_classargs(input_class_args, classname):
input_class_args[key] = value input_class_args[key] = value
return input_class_args return input_class_args
def create_builtin_class_instance(classname, jsonstr_args, is_advisor=False):
if is_advisor: def create_builtin_class_instance(class_name, class_args, builtin_module_dict, builtin_class_dict):
if classname not in AdvisorModuleName or \ if class_name not in builtin_module_dict or \
importlib.util.find_spec(AdvisorModuleName[classname]) is None: importlib.util.find_spec(builtin_module_dict[class_name]) is None:
raise RuntimeError('Advisor module is not found: {}'.format(classname)) raise RuntimeError('Builtin module is not found: {}'.format(class_name))
class_module = importlib.import_module(AdvisorModuleName[classname]) class_module = importlib.import_module(builtin_module_dict[class_name])
class_constructor = getattr(class_module, AdvisorClassName[classname]) class_constructor = getattr(class_module, builtin_class_dict[class_name])
else:
if classname not in ModuleName or \ if class_args is None:
importlib.util.find_spec(ModuleName[classname]) is None: class_args = {}
raise RuntimeError('Tuner module is not found: {}'.format(classname)) class_args = augment_classargs(class_args, class_name)
class_module = importlib.import_module(ModuleName[classname]) instance = class_constructor(**class_args)
class_constructor = getattr(class_module, ClassName[classname])
if jsonstr_args:
class_args = augment_classargs(json.loads(jsonstr_args), classname)
else:
class_args = augment_classargs({}, classname)
if class_args:
instance = class_constructor(**class_args)
else:
instance = class_constructor()
return instance return instance
def create_customized_class_instance(class_dir, class_filename, classname, jsonstr_args):
if not os.path.isfile(os.path.join(class_dir, class_filename)): def create_customized_class_instance(class_params):
code_dir = class_params.get('codeDir')
class_filename = class_params.get('classFileName')
class_name = class_params.get('className')
class_args = class_params.get('classArgs')
if not os.path.isfile(os.path.join(code_dir, class_filename)):
raise ValueError('Class file not found: {}'.format( raise ValueError('Class file not found: {}'.format(
os.path.join(class_dir, class_filename))) os.path.join(code_dir, class_filename)))
sys.path.append(class_dir) sys.path.append(code_dir)
module_name = os.path.splitext(class_filename)[0] module_name = os.path.splitext(class_filename)[0]
class_module = importlib.import_module(module_name) class_module = importlib.import_module(module_name)
class_constructor = getattr(class_module, classname) class_constructor = getattr(class_module, class_name)
if jsonstr_args:
class_args = json.loads(jsonstr_args) if class_args is None:
instance = class_constructor(**class_args) class_args = {}
else: instance = class_constructor(**class_args)
instance = class_constructor()
return instance return instance
def parse_args():
parser = argparse.ArgumentParser(description='parse command line parameters.')
parser.add_argument('--advisor_class_name', type=str, required=False,
help='Advisor class name, the class must be a subclass of nni.MsgDispatcherBase')
parser.add_argument('--advisor_class_filename', type=str, required=False,
help='Advisor class file path')
parser.add_argument('--advisor_args', type=str, required=False,
help='Parameters pass to advisor __init__ constructor')
parser.add_argument('--advisor_directory', type=str, required=False,
help='Advisor directory')
parser.add_argument('--tuner_class_name', type=str, required=False,
help='Tuner class name, the class must be a subclass of nni.Tuner')
parser.add_argument('--tuner_class_filename', type=str, required=False,
help='Tuner class file path')
parser.add_argument('--tuner_args', type=str, required=False,
help='Parameters pass to tuner __init__ constructor')
parser.add_argument('--tuner_directory', type=str, required=False,
help='Tuner directory')
parser.add_argument('--assessor_class_name', type=str, required=False,
help='Assessor class name, the class must be a subclass of nni.Assessor')
parser.add_argument('--assessor_args', type=str, required=False,
help='Parameters pass to assessor __init__ constructor')
parser.add_argument('--assessor_directory', type=str, required=False,
help='Assessor directory')
parser.add_argument('--assessor_class_filename', type=str, required=False,
help='Assessor class file path')
parser.add_argument('--multi_phase', action='store_true')
parser.add_argument('--multi_thread', action='store_true')
flags, _ = parser.parse_known_args()
return flags
def main(): def main():
''' parser = argparse.ArgumentParser(description='Dispatcher command line parser')
main function. parser.add_argument('--exp_params', type=str, required=True)
''' args, _ = parser.parse_known_args()
exp_params_decode = base64.b64decode(args.exp_params).decode('utf-8')
logger.debug('decoded exp_params: [%s]', exp_params_decode)
exp_params = json.loads(exp_params_decode)
logger.debug('exp_params json obj: [%s]', json.dumps(exp_params, indent=4))
args = parse_args() if exp_params.get('multiThread'):
if args.multi_thread:
enable_multi_thread() enable_multi_thread()
if args.multi_phase: if exp_params.get('multiPhase'):
enable_multi_phase() enable_multi_phase()
if args.advisor_class_name: if exp_params.get('advisor') is not None:
# advisor is enabled and starts to run # advisor is enabled and starts to run
_run_advisor(args) _run_advisor(exp_params)
else: else:
# tuner (and assessor) is enabled and starts to run # tuner (and assessor) is enabled and starts to run
tuner = _create_tuner(args) assert exp_params.get('tuner') is not None
if args.assessor_class_name: tuner = _create_tuner(exp_params)
assessor = _create_assessor(args) if exp_params.get('assessor') is not None:
assessor = _create_assessor(exp_params)
else: else:
assessor = None assessor = None
dispatcher = MsgDispatcher(tuner, assessor) dispatcher = MsgDispatcher(tuner, assessor)
...@@ -139,17 +105,14 @@ def main(): ...@@ -139,17 +105,14 @@ def main():
raise raise
def _run_advisor(args): def _run_advisor(exp_params):
if args.advisor_class_name in AdvisorModuleName: if exp_params.get('advisor').get('builtinAdvisorName') in AdvisorModuleName:
dispatcher = create_builtin_class_instance( dispatcher = create_builtin_class_instance(
args.advisor_class_name, exp_params.get('advisor').get('builtinAdvisorName'),
args.advisor_args, True) exp_params.get('advisor').get('classArgs'),
AdvisorModuleName, AdvisorClassName)
else: else:
dispatcher = create_customized_class_instance( dispatcher = create_customized_class_instance(exp_params.get('advisor'))
args.advisor_directory,
args.advisor_class_filename,
args.advisor_class_name,
args.advisor_args)
if dispatcher is None: if dispatcher is None:
raise AssertionError('Failed to create Advisor instance') raise AssertionError('Failed to create Advisor instance')
try: try:
...@@ -159,33 +122,27 @@ def _run_advisor(args): ...@@ -159,33 +122,27 @@ def _run_advisor(args):
raise raise
def _create_tuner(args): def _create_tuner(exp_params):
if args.tuner_class_name in ModuleName: if exp_params.get('tuner').get('builtinTunerName') in ModuleName:
tuner = create_builtin_class_instance( tuner = create_builtin_class_instance(
args.tuner_class_name, exp_params.get('tuner').get('builtinTunerName'),
args.tuner_args) exp_params.get('tuner').get('classArgs'),
ModuleName, ClassName)
else: else:
tuner = create_customized_class_instance( tuner = create_customized_class_instance(exp_params.get('tuner'))
args.tuner_directory,
args.tuner_class_filename,
args.tuner_class_name,
args.tuner_args)
if tuner is None: if tuner is None:
raise AssertionError('Failed to create Tuner instance') raise AssertionError('Failed to create Tuner instance')
return tuner return tuner
def _create_assessor(args): def _create_assessor(exp_params):
if args.assessor_class_name in ModuleName: if exp_params.get('assessor').get('builtinAssessorName') in ModuleName:
assessor = create_builtin_class_instance( assessor = create_builtin_class_instance(
args.assessor_class_name, exp_params.get('assessor').get('builtinAssessorName'),
args.assessor_args) exp_params.get('assessor').get('classArgs'),
ModuleName, ClassName)
else: else:
assessor = create_customized_class_instance( assessor = create_customized_class_instance(exp_params.get('assessor'))
args.assessor_directory,
args.assessor_class_filename,
args.assessor_class_name,
args.assessor_args)
if assessor is None: if assessor is None:
raise AssertionError('Failed to create Assessor instance') raise AssertionError('Failed to create Assessor instance')
return assessor return assessor
......
...@@ -213,24 +213,18 @@ def validate_customized_file(experiment_config, spec_key): ...@@ -213,24 +213,18 @@ def validate_customized_file(experiment_config, spec_key):
def parse_tuner_content(experiment_config): def parse_tuner_content(experiment_config):
'''Validate whether tuner in experiment_config is valid''' '''Validate whether tuner in experiment_config is valid'''
if experiment_config['tuner'].get('builtinTunerName'): if not experiment_config['tuner'].get('builtinTunerName'):
experiment_config['tuner']['className'] = experiment_config['tuner']['builtinTunerName']
else:
validate_customized_file(experiment_config, 'tuner') validate_customized_file(experiment_config, 'tuner')
def parse_assessor_content(experiment_config): def parse_assessor_content(experiment_config):
'''Validate whether assessor in experiment_config is valid''' '''Validate whether assessor in experiment_config is valid'''
if experiment_config.get('assessor'): if experiment_config.get('assessor'):
if experiment_config['assessor'].get('builtinAssessorName'): if not experiment_config['assessor'].get('builtinAssessorName'):
experiment_config['assessor']['className'] = experiment_config['assessor']['builtinAssessorName']
else:
validate_customized_file(experiment_config, 'assessor') validate_customized_file(experiment_config, 'assessor')
def parse_advisor_content(experiment_config): def parse_advisor_content(experiment_config):
'''Validate whether advisor in experiment_config is valid''' '''Validate whether advisor in experiment_config is valid'''
if experiment_config['advisor'].get('builtinAdvisorName'): if not experiment_config['advisor'].get('builtinAdvisorName'):
experiment_config['advisor']['className'] = experiment_config['advisor']['builtinAdvisorName']
else:
validate_customized_file(experiment_config, 'advisor') validate_customized_file(experiment_config, 'advisor')
def validate_annotation_content(experiment_config, spec_key, builtin_name): def validate_annotation_content(experiment_config, spec_key, builtin_name):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment