Unverified Commit f9ee589c authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #222 from microsoft/master

merge master
parents 36e6e350 4f3ee9cb
...@@ -41,6 +41,7 @@ class Experiment { ...@@ -41,6 +41,7 @@ class Experiment {
if (!this.profileField) { if (!this.profileField) {
throw Error('Experiment profile not initialized'); throw Error('Experiment profile not initialized');
} }
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.profileField!; return this.profileField!;
} }
...@@ -73,6 +74,7 @@ class Experiment { ...@@ -73,6 +74,7 @@ class Experiment {
if (!this.statusField) { if (!this.statusField) {
throw Error('Experiment status not initialized'); throw Error('Experiment status not initialized');
} }
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.statusField!.status; return this.statusField!.status;
} }
...@@ -80,6 +82,7 @@ class Experiment { ...@@ -80,6 +82,7 @@ class Experiment {
if (!this.statusField) { if (!this.statusField) {
throw Error('Experiment status not initialized'); throw Error('Experiment status not initialized');
} }
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.statusField!.errors[0] || ''; return this.statusField!.errors[0] || '';
} }
} }
......
...@@ -19,10 +19,12 @@ class Trial implements TableObj { ...@@ -19,10 +19,12 @@ class Trial implements TableObj {
if (!this.sortable || !otherTrial.sortable) { if (!this.sortable || !otherTrial.sortable) {
return undefined; return undefined;
} }
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.finalAcc! - otherTrial.finalAcc!; return this.finalAcc! - otherTrial.finalAcc!;
} }
get info(): TrialJobInfo { get info(): TrialJobInfo {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.infoField!; return this.infoField!;
} }
...@@ -30,6 +32,7 @@ class Trial implements TableObj { ...@@ -30,6 +32,7 @@ class Trial implements TableObj {
const ret: MetricDataRecord[] = [ ]; const ret: MetricDataRecord[] = [ ];
for (let i = 0; i < this.intermediates.length; i++) { for (let i = 0; i < this.intermediates.length; i++) {
if (this.intermediates[i]) { if (this.intermediates[i]) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
ret.push(this.intermediates[i]!); ret.push(this.intermediates[i]!);
} else { } else {
break; break;
...@@ -66,12 +69,14 @@ class Trial implements TableObj { ...@@ -66,12 +69,14 @@ class Trial implements TableObj {
get tableRecord(): TableRecord { get tableRecord(): TableRecord {
const endTime = this.info.endTime || new Date().getTime(); const endTime = this.info.endTime || new Date().getTime();
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const duration = (endTime - this.info.startTime!) / 1000; const duration = (endTime - this.info.startTime!) / 1000;
return { return {
key: this.info.id, key: this.info.id,
sequenceId: this.info.sequenceId, sequenceId: this.info.sequenceId,
id: this.info.id, id: this.info.id,
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
startTime: this.info.startTime!, startTime: this.info.startTime!,
endTime: this.info.endTime, endTime: this.info.endTime,
duration, duration,
...@@ -97,6 +102,7 @@ class Trial implements TableObj { ...@@ -97,6 +102,7 @@ class Trial implements TableObj {
get duration(): number { get duration(): number {
const endTime = this.info.endTime || new Date().getTime(); const endTime = this.info.endTime || new Date().getTime();
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return (endTime - this.info.startTime!) / 1000; return (endTime - this.info.startTime!) / 1000;
} }
...@@ -203,6 +209,7 @@ class Trial implements TableObj { ...@@ -203,6 +209,7 @@ class Trial implements TableObj {
} else if (this.intermediates.length === 0) { } else if (this.intermediates.length === 0) {
return '--'; return '--';
} else { } else {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const latest = this.intermediates[this.intermediates.length - 1]!; const latest = this.intermediates[this.intermediates.length - 1]!;
return `${formatAccuracy(metricAccuracy(latest))} (LATEST)`; return `${formatAccuracy(metricAccuracy(latest))} (LATEST)`;
} }
......
...@@ -7,6 +7,7 @@ function groupMetricsByTrial(metrics: MetricDataRecord[]): Map<string, MetricDat ...@@ -7,6 +7,7 @@ function groupMetricsByTrial(metrics: MetricDataRecord[]): Map<string, MetricDat
const ret = new Map<string, MetricDataRecord[]>(); const ret = new Map<string, MetricDataRecord[]>();
for (const metric of metrics) { for (const metric of metrics) {
if (ret.has(metric.trialJobId)) { if (ret.has(metric.trialJobId)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
ret.get(metric.trialJobId)!.push(metric); ret.get(metric.trialJobId)!.push(metric);
} else { } else {
ret.set(metric.trialJobId, [ metric ]); ret.set(metric.trialJobId, [ metric ]);
...@@ -35,14 +36,17 @@ class TrialManager { ...@@ -35,14 +36,17 @@ class TrialManager {
} }
public getTrial(trialId: string): Trial { public getTrial(trialId: string): Trial {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.trials.get(trialId)!; return this.trials.get(trialId)!;
} }
public getTrials(trialIds: string[]): Trial[] { public getTrials(trialIds: string[]): Trial[] {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return trialIds.map(trialId => this.trials.get(trialId)!); return trialIds.map(trialId => this.trials.get(trialId)!);
} }
public table(trialIds: string[]): TableRecord[] { public table(trialIds: string[]): TableRecord[] {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return trialIds.map(trialId => this.trials.get(trialId)!.tableRecord); return trialIds.map(trialId => this.trials.get(trialId)!.tableRecord);
} }
...@@ -61,6 +65,7 @@ class TrialManager { ...@@ -61,6 +65,7 @@ class TrialManager {
} }
public sort(): Trial[] { public sort(): Trial[] {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
return this.filter(trial => trial.sortable).sort((trial1, trial2) => trial1.compareAccuracy(trial2)!); return this.filter(trial => trial.sortable).sort((trial1, trial2) => trial1.compareAccuracy(trial2)!);
} }
...@@ -77,6 +82,7 @@ class TrialManager { ...@@ -77,6 +82,7 @@ class TrialManager {
]); ]);
for (const trial of this.trials.values()) { for (const trial of this.trials.values()) {
if (trial.initialized()) { if (trial.initialized()) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
cnt.set(trial.info.status, cnt.get(trial.info.status)! + 1); cnt.set(trial.info.status, cnt.get(trial.info.status)! + 1);
} }
} }
...@@ -89,6 +95,7 @@ class TrialManager { ...@@ -89,6 +95,7 @@ class TrialManager {
if (response.status === 200) { if (response.status === 200) {
for (const info of response.data as TrialJobInfo[]) { for (const info of response.data as TrialJobInfo[]) {
if (this.trials.has(info.id)) { if (this.trials.has(info.id)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
updated = this.trials.get(info.id)!.updateTrialJobInfo(info) || updated; updated = this.trials.get(info.id)!.updateTrialJobInfo(info) || updated;
} else { } else {
this.trials.set(info.id, new Trial(info, undefined)); this.trials.set(info.id, new Trial(info, undefined));
...@@ -141,6 +148,7 @@ class TrialManager { ...@@ -141,6 +148,7 @@ class TrialManager {
let updated = false; let updated = false;
for (const [ trialId, metrics ] of groupMetricsByTrial(allMetrics).entries()) { for (const [ trialId, metrics ] of groupMetricsByTrial(allMetrics).entries()) {
if (this.trials.has(trialId)) { if (this.trials.has(trialId)) {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const trial = this.trials.get(trialId)!; const trial = this.trials.get(trialId)!;
updated = (latestOnly ? trial.updateLatestMetrics(metrics) : trial.updateMetrics(metrics)) || updated; updated = (latestOnly ? trial.updateLatestMetrics(metrics) : trial.updateMetrics(metrics)) || updated;
} else { } else {
......
...@@ -14,11 +14,11 @@ def update_training_service_config(args): ...@@ -14,11 +14,11 @@ def update_training_service_config(args):
config[args.ts]['nniManagerIp'] = args.nni_manager_ip config[args.ts]['nniManagerIp'] = args.nni_manager_ip
if args.ts == 'pai': if args.ts == 'pai':
if args.pai_user is not None: if args.pai_user is not None:
config[args.ts]['paiConfig']['userName'] = args.pai_user config[args.ts]['paiYarnConfig']['userName'] = args.pai_user
if args.pai_pwd is not None: if args.pai_pwd is not None:
config[args.ts]['paiConfig']['passWord'] = args.pai_pwd config[args.ts]['paiYarnConfig']['passWord'] = args.pai_pwd
if args.pai_host is not None: if args.pai_host is not None:
config[args.ts]['paiConfig']['host'] = args.pai_host config[args.ts]['paiYarnConfig']['host'] = args.pai_host
if args.nni_docker_image is not None: if args.nni_docker_image is not None:
config[args.ts]['trial']['image'] = args.nni_docker_image config[args.ts]['trial']['image'] = args.nni_docker_image
if args.data_dir is not None: if args.data_dir is not None:
......
...@@ -29,11 +29,11 @@ local: ...@@ -29,11 +29,11 @@ local:
pai: pai:
nniManagerIp: nniManagerIp:
maxExecDuration: 15m maxExecDuration: 15m
paiConfig: paiYarnConfig:
host: host:
passWord: passWord:
userName: userName:
trainingServicePlatform: pai trainingServicePlatform: paiYarn
trial: trial:
gpuNum: 1 gpuNum: 1
cpuNum: 1 cpuNum: 1
......
...@@ -32,7 +32,7 @@ common_schema = { ...@@ -32,7 +32,7 @@ common_schema = {
'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999), 'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')), Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999), Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
'trainingServicePlatform': setChoice('trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller'), 'trainingServicePlatform': setChoice('trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn'),
Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'), Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
Optional('multiPhase'): setType('multiPhase', bool), Optional('multiPhase'): setType('multiPhase', bool),
Optional('multiThread'): setType('multiThread', bool), Optional('multiThread'): setType('multiThread', bool),
...@@ -53,14 +53,23 @@ common_schema = { ...@@ -53,14 +53,23 @@ common_schema = {
} }
} }
tuner_schema_dict = { tuner_schema_dict = {
('Anneal', 'SMAC'): { 'Anneal': {
'builtinTunerName': setChoice('builtinTunerName', 'Anneal', 'SMAC'), 'builtinTunerName': 'Anneal',
Optional('classArgs'): { Optional('classArgs'): {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'), 'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
}, },
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool), Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
}, },
'SMAC': {
'builtinTunerName': 'SMAC',
Optional('classArgs'): {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
'config_dedup': setType('config_dedup', bool)
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
('Evolution'): { ('Evolution'): {
'builtinTunerName': setChoice('builtinTunerName', 'Evolution'), 'builtinTunerName': setChoice('builtinTunerName', 'Evolution'),
Optional('classArgs'): { Optional('classArgs'): {
...@@ -223,7 +232,7 @@ common_trial_schema = { ...@@ -223,7 +232,7 @@ common_trial_schema = {
} }
} }
pai_trial_schema = { pai_yarn_trial_schema = {
'trial':{ 'trial':{
'command': setType('command', str), 'command': setType('command', str),
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
...@@ -247,6 +256,34 @@ pai_trial_schema = { ...@@ -247,6 +256,34 @@ pai_trial_schema = {
} }
} }
pai_yarn_config_schema = {
'paiYarnConfig': Or({
'userName': setType('userName', str),
'passWord': setType('passWord', str),
'host': setType('host', str)
}, {
'userName': setType('userName', str),
'token': setType('token', str),
'host': setType('host', str)
})
}
pai_trial_schema = {
'trial':{
'command': setType('command', str),
'codeDir': setPathCheck('codeDir'),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str),
Optional('virtualCluster'): setType('virtualCluster', str),
'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
'containerNFSMountPath': setType('containerNFSMountPath', str),
'paiStoragePlugin': setType('paiStoragePlugin', str)
}
}
pai_config_schema = { pai_config_schema = {
'paiConfig': Or({ 'paiConfig': Or({
'userName': setType('userName', str), 'userName': setType('userName', str),
...@@ -396,6 +433,8 @@ REMOTE_CONFIG_SCHEMA = Schema({**common_schema, **common_trial_schema, **machine ...@@ -396,6 +433,8 @@ REMOTE_CONFIG_SCHEMA = Schema({**common_schema, **common_trial_schema, **machine
PAI_CONFIG_SCHEMA = Schema({**common_schema, **pai_trial_schema, **pai_config_schema}) PAI_CONFIG_SCHEMA = Schema({**common_schema, **pai_trial_schema, **pai_config_schema})
PAI_YARN_CONFIG_SCHEMA = Schema({**common_schema, **pai_yarn_trial_schema, **pai_yarn_config_schema})
KUBEFLOW_CONFIG_SCHEMA = Schema({**common_schema, **kubeflow_trial_schema, **kubeflow_config_schema}) KUBEFLOW_CONFIG_SCHEMA = Schema({**common_schema, **kubeflow_trial_schema, **kubeflow_config_schema})
FRAMEWORKCONTROLLER_CONFIG_SCHEMA = Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}) FRAMEWORKCONTROLLER_CONFIG_SCHEMA = Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema})
...@@ -224,6 +224,25 @@ def set_pai_config(experiment_config, port, config_file_name): ...@@ -224,6 +224,25 @@ def set_pai_config(experiment_config, port, config_file_name):
#set trial_config #set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message return set_trial_config(experiment_config, port, config_file_name), err_message
def set_pai_yarn_config(experiment_config, port, config_file_name):
'''set paiYarn configuration'''
pai_yarn_config_data = dict()
pai_yarn_config_data['pai_yarn_config'] = experiment_config['paiYarnConfig']
response = rest_put(cluster_metadata_url(port), json.dumps(pai_yarn_config_data), REST_TIME_OUT)
err_message = None
if not response or not response.status_code == 200:
if response is not None:
err_message = response.text
_, stderr_full_path = get_log_path(config_file_name)
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message
result, message = setNNIManagerIp(experiment_config, port, config_file_name)
if not result:
return result, message
#set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message
def set_kubeflow_config(experiment_config, port, config_file_name): def set_kubeflow_config(experiment_config, port, config_file_name):
'''set kubeflow configuration''' '''set kubeflow configuration'''
kubeflow_config_data = dict() kubeflow_config_data = dict()
...@@ -320,6 +339,11 @@ def set_experiment(experiment_config, mode, port, config_file_name): ...@@ -320,6 +339,11 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{'key': 'pai_config', 'value': experiment_config['paiConfig']}) {'key': 'pai_config', 'value': experiment_config['paiConfig']})
request_data['clusterMetaData'].append( request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']}) {'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'paiYarn':
request_data['clusterMetaData'].append(
{'key': 'pai_yarn_config', 'value': experiment_config['paiYarnConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'kubeflow': elif experiment_config['trainingServicePlatform'] == 'kubeflow':
request_data['clusterMetaData'].append( request_data['clusterMetaData'].append(
{'key': 'kubeflow_config', 'value': experiment_config['kubeflowConfig']}) {'key': 'kubeflow_config', 'value': experiment_config['kubeflowConfig']})
...@@ -351,6 +375,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res ...@@ -351,6 +375,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result, err_msg = set_remote_config(experiment_config, port, config_file_name) config_result, err_msg = set_remote_config(experiment_config, port, config_file_name)
elif platform == 'pai': elif platform == 'pai':
config_result, err_msg = set_pai_config(experiment_config, port, config_file_name) config_result, err_msg = set_pai_config(experiment_config, port, config_file_name)
elif platform == 'paiYarn':
config_result, err_msg = set_pai_yarn_config(experiment_config, port, config_file_name)
elif platform == 'kubeflow': elif platform == 'kubeflow':
config_result, err_msg = set_kubeflow_config(experiment_config, port, config_file_name) config_result, err_msg = set_kubeflow_config(experiment_config, port, config_file_name)
elif platform == 'frameworkcontroller': elif platform == 'frameworkcontroller':
......
...@@ -5,7 +5,7 @@ import os ...@@ -5,7 +5,7 @@ import os
import json import json
from schema import SchemaError from schema import SchemaError
from schema import Schema from schema import Schema
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\ from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, PAI_YARN_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\
FRAMEWORKCONTROLLER_CONFIG_SCHEMA, tuner_schema_dict, advisor_schema_dict, assessor_schema_dict FRAMEWORKCONTROLLER_CONFIG_SCHEMA, tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
from .common_utils import print_error, print_warning, print_normal from .common_utils import print_error, print_warning, print_normal
...@@ -143,13 +143,14 @@ def validate_kubeflow_operators(experiment_config): ...@@ -143,13 +143,14 @@ def validate_kubeflow_operators(experiment_config):
def validate_common_content(experiment_config): def validate_common_content(experiment_config):
'''Validate whether the common values in experiment_config is valid''' '''Validate whether the common values in experiment_config is valid'''
if not experiment_config.get('trainingServicePlatform') or \ if not experiment_config.get('trainingServicePlatform') or \
experiment_config.get('trainingServicePlatform') not in ['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller']: experiment_config.get('trainingServicePlatform') not in ['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn']:
print_error('Please set correct trainingServicePlatform!') print_error('Please set correct trainingServicePlatform!')
exit(1) exit(1)
schema_dict = { schema_dict = {
'local': LOCAL_CONFIG_SCHEMA, 'local': LOCAL_CONFIG_SCHEMA,
'remote': REMOTE_CONFIG_SCHEMA, 'remote': REMOTE_CONFIG_SCHEMA,
'pai': PAI_CONFIG_SCHEMA, 'pai': PAI_CONFIG_SCHEMA,
'paiYarn': PAI_YARN_CONFIG_SCHEMA,
'kubeflow': KUBEFLOW_CONFIG_SCHEMA, 'kubeflow': KUBEFLOW_CONFIG_SCHEMA,
'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA 'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA
} }
...@@ -213,24 +214,18 @@ def validate_customized_file(experiment_config, spec_key): ...@@ -213,24 +214,18 @@ def validate_customized_file(experiment_config, spec_key):
def parse_tuner_content(experiment_config): def parse_tuner_content(experiment_config):
'''Validate whether tuner in experiment_config is valid''' '''Validate whether tuner in experiment_config is valid'''
if experiment_config['tuner'].get('builtinTunerName'): if not experiment_config['tuner'].get('builtinTunerName'):
experiment_config['tuner']['className'] = experiment_config['tuner']['builtinTunerName']
else:
validate_customized_file(experiment_config, 'tuner') validate_customized_file(experiment_config, 'tuner')
def parse_assessor_content(experiment_config): def parse_assessor_content(experiment_config):
'''Validate whether assessor in experiment_config is valid''' '''Validate whether assessor in experiment_config is valid'''
if experiment_config.get('assessor'): if experiment_config.get('assessor'):
if experiment_config['assessor'].get('builtinAssessorName'): if not experiment_config['assessor'].get('builtinAssessorName'):
experiment_config['assessor']['className'] = experiment_config['assessor']['builtinAssessorName']
else:
validate_customized_file(experiment_config, 'assessor') validate_customized_file(experiment_config, 'assessor')
def parse_advisor_content(experiment_config): def parse_advisor_content(experiment_config):
'''Validate whether advisor in experiment_config is valid''' '''Validate whether advisor in experiment_config is valid'''
if experiment_config['advisor'].get('builtinAdvisorName'): if not experiment_config['advisor'].get('builtinAdvisorName'):
experiment_config['advisor']['className'] = experiment_config['advisor']['builtinAdvisorName']
else:
validate_customized_file(experiment_config, 'advisor') validate_customized_file(experiment_config, 'advisor')
def validate_annotation_content(experiment_config, spec_key, builtin_name): def validate_annotation_content(experiment_config, spec_key, builtin_name):
...@@ -261,7 +256,7 @@ def validate_machine_list(experiment_config): ...@@ -261,7 +256,7 @@ def validate_machine_list(experiment_config):
def validate_pai_trial_conifg(experiment_config): def validate_pai_trial_conifg(experiment_config):
'''validate the trial config in pai platform''' '''validate the trial config in pai platform'''
if experiment_config.get('trainingServicePlatform') == 'pai': if experiment_config.get('trainingServicePlatform') in ['pai', 'paiYarn']:
if experiment_config.get('trial').get('shmMB') and \ if experiment_config.get('trial').get('shmMB') and \
experiment_config['trial']['shmMB'] > experiment_config['trial']['memoryMB']: experiment_config['trial']['shmMB'] > experiment_config['trial']['memoryMB']:
print_error('shmMB should be no more than memoryMB!') print_error('shmMB should be no more than memoryMB!')
......
...@@ -682,10 +682,13 @@ def search_space_auto_gen(args): ...@@ -682,10 +682,13 @@ def search_space_auto_gen(args):
trial_dir = os.path.expanduser(args.trial_dir) trial_dir = os.path.expanduser(args.trial_dir)
file_path = os.path.expanduser(args.file) file_path = os.path.expanduser(args.file)
if not os.path.isabs(file_path): if not os.path.isabs(file_path):
abs_file_path = os.path.join(os.getcwd(), file_path) file_path = os.path.join(os.getcwd(), file_path)
assert os.path.exists(trial_dir) assert os.path.exists(trial_dir)
if os.path.exists(abs_file_path): if os.path.exists(file_path):
print_warning('%s already exits, will be over written' % abs_file_path) print_warning('%s already exists, will be overwritten.' % file_path)
print_normal('Dry run to generate search space...') print_normal('Dry run to generate search space...')
Popen(args.trial_command, cwd=trial_dir, env=dict(os.environ, NNI_GEN_SEARCH_SPACE=abs_file_path), shell=True).wait() Popen(args.trial_command, cwd=trial_dir, env=dict(os.environ, NNI_GEN_SEARCH_SPACE=file_path), shell=True).wait()
print_normal('Dry run to generate search space, Done') if not os.path.exists(file_path):
\ No newline at end of file print_warning('Expected search space file \'{}\' generated, but not found.'.format(file_path))
else:
print_normal('Generate search space done: \'{}\'.'.format(file_path))
...@@ -223,7 +223,7 @@ if __name__ == '__main__': ...@@ -223,7 +223,7 @@ if __name__ == '__main__':
exit(1) exit(1)
check_version(args) check_version(args)
try: try:
if NNI_PLATFORM == 'pai' and is_multi_phase(): if NNI_PLATFORM == 'paiYarn' and is_multi_phase():
fetch_parameter_file(args) fetch_parameter_file(args)
main_loop(args) main_loop(args)
except SystemExit as se: except SystemExit as se:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment