Unverified Commit f9ee589c authored by SparkSnail, committed by GitHub

Merge pull request #222 from microsoft/master

merge master
parents 36e6e350 4f3ee9cb
......@@ -41,6 +41,7 @@ class Experiment {
if (!this.profileField) {
throw Error('Experiment profile not initialized');
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.profileField!;
}
......@@ -73,6 +74,7 @@ class Experiment {
if (!this.statusField) {
throw Error('Experiment status not initialized');
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.statusField!.status;
}
......@@ -80,6 +82,7 @@ class Experiment {
if (!this.statusField) {
throw Error('Experiment status not initialized');
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.statusField!.errors[0] || '';
}
}
......
......@@ -19,10 +19,12 @@ class Trial implements TableObj {
if (!this.sortable || !otherTrial.sortable) {
return undefined;
}
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.finalAcc! - otherTrial.finalAcc!;
}
get info(): TrialJobInfo {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.infoField!;
}
......@@ -30,6 +32,7 @@ class Trial implements TableObj {
const ret: MetricDataRecord[] = [ ];
for (let i = 0; i < this.intermediates.length; i++) {
if (this.intermediates[i]) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
ret.push(this.intermediates[i]!);
} else {
break;
......@@ -66,12 +69,14 @@ class Trial implements TableObj {
get tableRecord(): TableRecord {
const endTime = this.info.endTime || new Date().getTime();
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const duration = (endTime - this.info.startTime!) / 1000;
return {
key: this.info.id,
sequenceId: this.info.sequenceId,
id: this.info.id,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
startTime: this.info.startTime!,
endTime: this.info.endTime,
duration,
......@@ -97,6 +102,7 @@ class Trial implements TableObj {
get duration(): number {
const endTime = this.info.endTime || new Date().getTime();
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return (endTime - this.info.startTime!) / 1000;
}
......@@ -203,6 +209,7 @@ class Trial implements TableObj {
} else if (this.intermediates.length === 0) {
return '--';
} else {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const latest = this.intermediates[this.intermediates.length - 1]!;
return `${formatAccuracy(metricAccuracy(latest))} (LATEST)`;
}
......
......@@ -7,6 +7,7 @@ function groupMetricsByTrial(metrics: MetricDataRecord[]): Map<string, MetricDat
const ret = new Map<string, MetricDataRecord[]>();
for (const metric of metrics) {
if (ret.has(metric.trialJobId)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
ret.get(metric.trialJobId)!.push(metric);
} else {
ret.set(metric.trialJobId, [ metric ]);
......@@ -35,14 +36,17 @@ class TrialManager {
}
public getTrial(trialId: string): Trial {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.trials.get(trialId)!;
}
public getTrials(trialIds: string[]): Trial[] {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return trialIds.map(trialId => this.trials.get(trialId)!);
}
public table(trialIds: string[]): TableRecord[] {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return trialIds.map(trialId => this.trials.get(trialId)!.tableRecord);
}
......@@ -61,6 +65,7 @@ class TrialManager {
}
public sort(): Trial[] {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.filter(trial => trial.sortable).sort((trial1, trial2) => trial1.compareAccuracy(trial2)!);
}
......@@ -77,6 +82,7 @@ class TrialManager {
]);
for (const trial of this.trials.values()) {
if (trial.initialized()) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
cnt.set(trial.info.status, cnt.get(trial.info.status)! + 1);
}
}
......@@ -89,6 +95,7 @@ class TrialManager {
if (response.status === 200) {
for (const info of response.data as TrialJobInfo[]) {
if (this.trials.has(info.id)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
updated = this.trials.get(info.id)!.updateTrialJobInfo(info) || updated;
} else {
this.trials.set(info.id, new Trial(info, undefined));
......@@ -141,6 +148,7 @@ class TrialManager {
let updated = false;
for (const [ trialId, metrics ] of groupMetricsByTrial(allMetrics).entries()) {
if (this.trials.has(trialId)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const trial = this.trials.get(trialId)!;
updated = (latestOnly ? trial.updateLatestMetrics(metrics) : trial.updateMetrics(metrics)) || updated;
} else {
......
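For readers skimming the TypeScript above, groupMetricsByTrial simply buckets metric records by their trialJobId so later lookups can read the map directly. A rough restatement of that grouping pattern in Python (illustrative only, not part of the codebase; dict-style records are assumed):

    from collections import defaultdict

    def group_metrics_by_trial(metrics):
        # Bucket each metric record under its trial job id.
        grouped = defaultdict(list)
        for metric in metrics:
            grouped[metric['trialJobId']].append(metric)
        return dict(grouped)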
......@@ -14,11 +14,11 @@ def update_training_service_config(args):
config[args.ts]['nniManagerIp'] = args.nni_manager_ip
if args.ts == 'pai':
if args.pai_user is not None:
config[args.ts]['paiConfig']['userName'] = args.pai_user
config[args.ts]['paiYarnConfig']['userName'] = args.pai_user
if args.pai_pwd is not None:
config[args.ts]['paiConfig']['passWord'] = args.pai_pwd
config[args.ts]['paiYarnConfig']['passWord'] = args.pai_pwd
if args.pai_host is not None:
config[args.ts]['paiConfig']['host'] = args.pai_host
config[args.ts]['paiYarnConfig']['host'] = args.pai_host
if args.nni_docker_image is not None:
config[args.ts]['trial']['image'] = args.nni_docker_image
if args.data_dir is not None:
......
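The test helper above rewrites keys of the training-service YAML in place, switching the section name from paiConfig to paiYarnConfig. A minimal sketch of that pattern with PyYAML, assuming the file layout shown in the next hunk; the function name and argument list are made up for illustration:

    import yaml

    def patch_pai_yarn_config(path, user=None, password=None, host=None):
        # Load the training-service config, update the paiYarn credentials, write it back.
        with open(path) as f:
            config = yaml.safe_load(f)
        section = config['pai']['paiYarnConfig']
        if user is not None:
            section['userName'] = user
        if password is not None:
            section['passWord'] = password
        if host is not None:
            section['host'] = host
        with open(path, 'w') as f:
            yaml.safe_dump(config, f, default_flow_style=False)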
......@@ -29,11 +29,11 @@ local:
pai:
  nniManagerIp:
  maxExecDuration: 15m
  paiConfig:
  paiYarnConfig:
    host:
    passWord:
    userName:
  trainingServicePlatform: pai
  trainingServicePlatform: paiYarn
  trial:
    gpuNum: 1
    cpuNum: 1
......
......@@ -32,7 +32,7 @@ common_schema = {
'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
'trainingServicePlatform': setChoice('trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller'),
'trainingServicePlatform': setChoice('trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn'),
Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
Optional('multiPhase'): setType('multiPhase', bool),
Optional('multiThread'): setType('multiThread', bool),
......@@ -53,14 +53,23 @@ common_schema = {
}
}
tuner_schema_dict = {
('Anneal', 'SMAC'): {
'builtinTunerName': setChoice('builtinTunerName', 'Anneal', 'SMAC'),
'Anneal': {
'builtinTunerName': 'Anneal',
Optional('classArgs'): {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
'SMAC': {
'builtinTunerName': 'SMAC',
Optional('classArgs'): {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
'config_dedup': setType('config_dedup', bool)
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
('Evolution'): {
'builtinTunerName': setChoice('builtinTunerName', 'Evolution'),
Optional('classArgs'): {
......@@ -223,7 +232,7 @@ common_trial_schema = {
}
}
pai_trial_schema = {
pai_yarn_trial_schema = {
'trial':{
'command': setType('command', str),
'codeDir': setPathCheck('codeDir'),
......@@ -247,6 +256,34 @@ pai_trial_schema = {
}
}
pai_yarn_config_schema = {
'paiYarnConfig': Or({
'userName': setType('userName', str),
'passWord': setType('passWord', str),
'host': setType('host', str)
}, {
'userName': setType('userName', str),
'token': setType('token', str),
'host': setType('host', str)
})
}
pai_trial_schema = {
'trial':{
'command': setType('command', str),
'codeDir': setPathCheck('codeDir'),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str),
Optional('virtualCluster'): setType('virtualCluster', str),
'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
'containerNFSMountPath': setType('containerNFSMountPath', str),
'paiStoragePlugin': setType('paiStoragePlugin', str)
}
}
pai_config_schema = {
'paiConfig': Or({
'userName': setType('userName', str),
......@@ -396,6 +433,8 @@ REMOTE_CONFIG_SCHEMA = Schema({**common_schema, **common_trial_schema, **machine
PAI_CONFIG_SCHEMA = Schema({**common_schema, **pai_trial_schema, **pai_config_schema})
PAI_YARN_CONFIG_SCHEMA = Schema({**common_schema, **pai_yarn_trial_schema, **pai_yarn_config_schema})
KUBEFLOW_CONFIG_SCHEMA = Schema({**common_schema, **kubeflow_trial_schema, **kubeflow_config_schema})
FRAMEWORKCONTROLLER_CONFIG_SCHEMA = Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema})
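The new paiYarnConfig block accepts either userName/passWord or userName/token, always together with host. A minimal sketch of how that Or-of-dicts rule behaves with the schema package, simplified to plain str checks instead of the setType helpers used above:

    from schema import Schema, Or, SchemaError

    pai_yarn_config_check = Schema({
        'paiYarnConfig': Or(
            {'userName': str, 'passWord': str, 'host': str},
            {'userName': str, 'token': str, 'host': str},
        )
    })

    # Passes: password-based login.
    pai_yarn_config_check.validate({'paiYarnConfig': {'userName': 'alice', 'passWord': 'secret', 'host': '10.0.0.1'}})

    # Fails: mixing passWord and token matches neither alternative.
    try:
        pai_yarn_config_check.validate({'paiYarnConfig': {'userName': 'alice', 'passWord': 'x', 'token': 'y', 'host': 'h'}})
    except SchemaError as err:
        print(err)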
......@@ -224,6 +224,25 @@ def set_pai_config(experiment_config, port, config_file_name):
#set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message
def set_pai_yarn_config(experiment_config, port, config_file_name):
    '''set paiYarn configuration'''
    pai_yarn_config_data = dict()
    pai_yarn_config_data['pai_yarn_config'] = experiment_config['paiYarnConfig']
    response = rest_put(cluster_metadata_url(port), json.dumps(pai_yarn_config_data), REST_TIME_OUT)
    err_message = None
    if not response or not response.status_code == 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), err_message
def set_kubeflow_config(experiment_config, port, config_file_name):
'''set kubeflow configuration'''
kubeflow_config_data = dict()
......@@ -320,6 +339,11 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{'key': 'pai_config', 'value': experiment_config['paiConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'paiYarn':
request_data['clusterMetaData'].append(
{'key': 'pai_yarn_config', 'value': experiment_config['paiYarnConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'kubeflow':
request_data['clusterMetaData'].append(
{'key': 'kubeflow_config', 'value': experiment_config['kubeflowConfig']})
......@@ -351,6 +375,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result, err_msg = set_remote_config(experiment_config, port, config_file_name)
elif platform == 'pai':
config_result, err_msg = set_pai_config(experiment_config, port, config_file_name)
elif platform == 'paiYarn':
config_result, err_msg = set_pai_yarn_config(experiment_config, port, config_file_name)
elif platform == 'kubeflow':
config_result, err_msg = set_kubeflow_config(experiment_config, port, config_file_name)
elif platform == 'frameworkcontroller':
......
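Behind set_pai_yarn_config, the launcher serializes the paiYarnConfig section and PUTs it to the experiment's cluster-metadata endpoint under the pai_yarn_config key. A minimal sketch of that call shape, assuming the metadata URL is supplied by the caller (the real code builds it with cluster_metadata_url and uses its own rest_put helper):

    import json
    import requests

    def put_pai_yarn_metadata(metadata_url, pai_yarn_config, timeout=20):
        # Send {'pai_yarn_config': {...}} to the NNI manager's cluster-metadata endpoint.
        payload = json.dumps({'pai_yarn_config': pai_yarn_config})
        response = requests.put(metadata_url, data=payload, timeout=timeout)
        return response.status_code == 200, response.text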
......@@ -5,7 +5,7 @@ import os
import json
from schema import SchemaError
from schema import Schema
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, PAI_YARN_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\
FRAMEWORKCONTROLLER_CONFIG_SCHEMA, tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
from .common_utils import print_error, print_warning, print_normal
......@@ -143,13 +143,14 @@ def validate_kubeflow_operators(experiment_config):
def validate_common_content(experiment_config):
'''Validate whether the common values in experiment_config is valid'''
if not experiment_config.get('trainingServicePlatform') or \
experiment_config.get('trainingServicePlatform') not in ['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller']:
experiment_config.get('trainingServicePlatform') not in ['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn']:
print_error('Please set correct trainingServicePlatform!')
exit(1)
schema_dict = {
'local': LOCAL_CONFIG_SCHEMA,
'remote': REMOTE_CONFIG_SCHEMA,
'pai': PAI_CONFIG_SCHEMA,
'paiYarn': PAI_YARN_CONFIG_SCHEMA,
'kubeflow': KUBEFLOW_CONFIG_SCHEMA,
'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA
}
......@@ -213,24 +214,18 @@ def validate_customized_file(experiment_config, spec_key):
def parse_tuner_content(experiment_config):
'''Validate whether tuner in experiment_config is valid'''
if experiment_config['tuner'].get('builtinTunerName'):
experiment_config['tuner']['className'] = experiment_config['tuner']['builtinTunerName']
else:
if not experiment_config['tuner'].get('builtinTunerName'):
validate_customized_file(experiment_config, 'tuner')
def parse_assessor_content(experiment_config):
'''Validate whether assessor in experiment_config is valid'''
if experiment_config.get('assessor'):
if experiment_config['assessor'].get('builtinAssessorName'):
experiment_config['assessor']['className'] = experiment_config['assessor']['builtinAssessorName']
else:
if not experiment_config['assessor'].get('builtinAssessorName'):
validate_customized_file(experiment_config, 'assessor')
def parse_advisor_content(experiment_config):
'''Validate whether advisor in experiment_config is valid'''
if experiment_config['advisor'].get('builtinAdvisorName'):
experiment_config['advisor']['className'] = experiment_config['advisor']['builtinAdvisorName']
else:
if not experiment_config['advisor'].get('builtinAdvisorName'):
validate_customized_file(experiment_config, 'advisor')
def validate_annotation_content(experiment_config, spec_key, builtin_name):
......@@ -261,7 +256,7 @@ def validate_machine_list(experiment_config):
def validate_pai_trial_conifg(experiment_config):
'''validate the trial config in pai platform'''
if experiment_config.get('trainingServicePlatform') == 'pai':
if experiment_config.get('trainingServicePlatform') in ['pai', 'paiYarn']:
if experiment_config.get('trial').get('shmMB') and \
experiment_config['trial']['shmMB'] > experiment_config['trial']['memoryMB']:
print_error('shmMB should be no more than memoryMB!')
......
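The validator above picks a schema by trainingServicePlatform and rejects anything outside the supported list, now including paiYarn. A minimal sketch of that selection-and-validate flow, with the schema table passed in rather than imported so the error handling stays visible:

    from schema import SchemaError

    def validate_experiment_config(experiment_config, schema_dict):
        # schema_dict maps platform names to Schema objects, e.g. {'paiYarn': PAI_YARN_CONFIG_SCHEMA, ...}.
        platform = experiment_config.get('trainingServicePlatform')
        if platform not in schema_dict:
            raise ValueError('Please set correct trainingServicePlatform!')
        try:
            schema_dict[platform].validate(experiment_config)
        except SchemaError as err:
            raise ValueError('Experiment config validation failed: {}'.format(err))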
......@@ -682,10 +682,13 @@ def search_space_auto_gen(args):
trial_dir = os.path.expanduser(args.trial_dir)
file_path = os.path.expanduser(args.file)
if not os.path.isabs(file_path):
abs_file_path = os.path.join(os.getcwd(), file_path)
file_path = os.path.join(os.getcwd(), file_path)
assert os.path.exists(trial_dir)
if os.path.exists(abs_file_path):
print_warning('%s already exits, will be over written' % abs_file_path)
if os.path.exists(file_path):
print_warning('%s already exists, will be overwritten.' % file_path)
print_normal('Dry run to generate search space...')
Popen(args.trial_command, cwd=trial_dir, env=dict(os.environ, NNI_GEN_SEARCH_SPACE=abs_file_path), shell=True).wait()
print_normal('Dry run to generate search space, Done')
\ No newline at end of file
Popen(args.trial_command, cwd=trial_dir, env=dict(os.environ, NNI_GEN_SEARCH_SPACE=file_path), shell=True).wait()
if not os.path.exists(file_path):
print_warning('Expected search space file \'{}\' generated, but not found.'.format(file_path))
else:
print_normal('Generate search space done: \'{}\'.'.format(file_path))
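The reworked search_space_auto_gen keeps the dry-run idea: run the trial command once with NNI_GEN_SEARCH_SPACE pointing at the output path, then check whether the file actually appeared. A minimal standalone sketch of that pattern; the function name and arguments are illustrative:

    import os
    from subprocess import Popen

    def dry_run_search_space(trial_command, trial_dir, out_file):
        # Resolve the output path, run the trial once in dry-run mode, and report
        # whether the search space file was generated.
        out_file = os.path.abspath(os.path.expanduser(out_file))
        env = dict(os.environ, NNI_GEN_SEARCH_SPACE=out_file)
        Popen(trial_command, cwd=trial_dir, env=env, shell=True).wait()
        return os.path.exists(out_file)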
......@@ -223,7 +223,7 @@ if __name__ == '__main__':
exit(1)
check_version(args)
try:
if NNI_PLATFORM == 'pai' and is_multi_phase():
if NNI_PLATFORM == 'paiYarn' and is_multi_phase():
fetch_parameter_file(args)
main_loop(args)
except SystemExit as se:
......
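The trial keeper change above only swaps the platform gate from pai to paiYarn: parameter files are pre-fetched when the trial runs on paiYarn and multi-phase is enabled. A tiny sketch of that condition, with the keeper's own helpers passed in since they are not shown here:

    import os

    def maybe_fetch_parameters(args, is_multi_phase, fetch_parameter_file):
        # Only paiYarn multi-phase trials need the parameter file fetched up front.
        if os.environ.get('NNI_PLATFORM') == 'paiYarn' and is_multi_phase():
            fetch_parameter_file(args)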