Unverified commit 9cbbf6f8, authored by SparkSnail and committed by GitHub
Browse files

Support pai and paiYarn trainingservice (#1853)

parent 9d01d083
......@@ -14,11 +14,11 @@ def update_training_service_config(args):
config[args.ts]['nniManagerIp'] = args.nni_manager_ip
if args.ts == 'pai':
if args.pai_user is not None:
config[args.ts]['paiConfig']['userName'] = args.pai_user
config[args.ts]['paiYarnConfig']['userName'] = args.pai_user
if args.pai_pwd is not None:
config[args.ts]['paiConfig']['passWord'] = args.pai_pwd
config[args.ts]['paiYarnConfig']['passWord'] = args.pai_pwd
if args.pai_host is not None:
config[args.ts]['paiConfig']['host'] = args.pai_host
config[args.ts]['paiYarnConfig']['host'] = args.pai_host
if args.nni_docker_image is not None:
config[args.ts]['trial']['image'] = args.nni_docker_image
if args.data_dir is not None:
......
......@@ -29,11 +29,11 @@ local:
pai:
nniManagerIp:
maxExecDuration: 15m
paiConfig:
paiYarnConfig:
host:
passWord:
userName:
trainingServicePlatform: pai
trainingServicePlatform: paiYarn
trial:
gpuNum: 1
cpuNum: 1
......
......@@ -32,7 +32,7 @@ common_schema = {
'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
'trainingServicePlatform': setChoice('trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller'),
'trainingServicePlatform': setChoice('trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn'),
Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
Optional('multiPhase'): setType('multiPhase', bool),
Optional('multiThread'): setType('multiThread', bool),
......@@ -232,7 +232,7 @@ common_trial_schema = {
}
}
pai_trial_schema = {
pai_yarn_trial_schema = {
'trial':{
'command': setType('command', str),
'codeDir': setPathCheck('codeDir'),
......@@ -256,6 +256,34 @@ pai_trial_schema = {
}
}
# Credential layout 1 for 'paiYarnConfig': user name + password + host.
_pai_yarn_password_login = {
    'userName': setType('userName', str),
    'passWord': setType('passWord', str),
    'host': setType('host', str)
}
# Credential layout 2 for 'paiYarnConfig': user name + token + host.
_pai_yarn_token_login = {
    'userName': setType('userName', str),
    'token': setType('token', str),
    'host': setType('host', str)
}
# The 'paiYarnConfig' section must match one of the two layouts above.
pai_yarn_config_schema = {
    'paiYarnConfig': Or(_pai_yarn_password_login, _pai_yarn_token_login)
}
# Field validators for the 'trial' section of a 'pai' experiment config.
# 'virtualCluster' is the only optional key; every other key is required.
_pai_trial_fields = {
    'command': setType('command', str),
    'codeDir': setPathCheck('codeDir'),
    'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
    'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
    'memoryMB': setType('memoryMB', int),
    'image': setType('image', str),
    Optional('virtualCluster'): setType('virtualCluster', str),
    # NFS mount locations on the NNI manager side and inside the container.
    'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
    'containerNFSMountPath': setType('containerNFSMountPath', str),
    'paiStoragePlugin': setType('paiStoragePlugin', str)
}
pai_trial_schema = {'trial': _pai_trial_fields}
pai_config_schema = {
'paiConfig': Or({
'userName': setType('userName', str),
......@@ -405,6 +433,8 @@ REMOTE_CONFIG_SCHEMA = Schema({**common_schema, **common_trial_schema, **machine
# Full experiment-config schemas, one per training service platform. Each one
# merges common_schema with that platform's trial schema and cluster-config
# schema; on a duplicate key the dict listed later in the merge wins.
PAI_CONFIG_SCHEMA = Schema({**common_schema, **pai_trial_schema, **pai_config_schema})
PAI_YARN_CONFIG_SCHEMA = Schema({**common_schema, **pai_yarn_trial_schema, **pai_yarn_config_schema})
KUBEFLOW_CONFIG_SCHEMA = Schema({**common_schema, **kubeflow_trial_schema, **kubeflow_config_schema})
FRAMEWORKCONTROLLER_CONFIG_SCHEMA = Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema})
......@@ -224,6 +224,25 @@ def set_pai_config(experiment_config, port, config_file_name):
#set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message
def set_pai_yarn_config(experiment_config, port, config_file_name):
    '''Set paiYarn configuration.

    PUTs experiment_config['paiYarnConfig'] to the cluster-metadata REST
    endpoint on `port`, then sets the NNI manager IP and the trial config.

    Returns a (result, err_message) tuple:
      * (False, err_message) when the PUT fails; err_message is the response
        body, or None when no response object was returned at all.
      * (result, message) from setNNIManagerIp when that step reports failure.
      * (set_trial_config(...), None) otherwise.
    '''
    pai_yarn_config_data = {'pai_yarn_config': experiment_config['paiYarnConfig']}
    response = rest_put(cluster_metadata_url(port), json.dumps(pai_yarn_config_data), REST_TIME_OUT)
    err_message = None
    if not response or not response.status_code == 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            try:
                # Pretty-print the error body when it is valid JSON.
                log_text = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
            except (TypeError, ValueError):
                # The body is not JSON (e.g. a plain-text or HTML error page):
                # log it verbatim instead of raising and losing the failure result.
                log_text = str(err_message)
            with open(stderr_full_path, 'a+') as fout:
                fout.write(log_text)
        return False, err_message
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), err_message
def set_kubeflow_config(experiment_config, port, config_file_name):
'''set kubeflow configuration'''
kubeflow_config_data = dict()
......@@ -320,6 +339,11 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{'key': 'pai_config', 'value': experiment_config['paiConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'paiYarn':
request_data['clusterMetaData'].append(
{'key': 'pai_yarn_config', 'value': experiment_config['paiYarnConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'kubeflow':
request_data['clusterMetaData'].append(
{'key': 'kubeflow_config', 'value': experiment_config['kubeflowConfig']})
......@@ -351,6 +375,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result, err_msg = set_remote_config(experiment_config, port, config_file_name)
elif platform == 'pai':
config_result, err_msg = set_pai_config(experiment_config, port, config_file_name)
elif platform == 'paiYarn':
config_result, err_msg = set_pai_yarn_config(experiment_config, port, config_file_name)
elif platform == 'kubeflow':
config_result, err_msg = set_kubeflow_config(experiment_config, port, config_file_name)
elif platform == 'frameworkcontroller':
......
......@@ -5,7 +5,7 @@ import os
import json
from schema import SchemaError
from schema import Schema
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, PAI_YARN_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\
FRAMEWORKCONTROLLER_CONFIG_SCHEMA, tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
from .common_utils import print_error, print_warning, print_normal
......@@ -143,13 +143,14 @@ def validate_kubeflow_operators(experiment_config):
def validate_common_content(experiment_config):
'''Validate whether the common values in experiment_config is valid'''
if not experiment_config.get('trainingServicePlatform') or \
experiment_config.get('trainingServicePlatform') not in ['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller']:
experiment_config.get('trainingServicePlatform') not in ['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn']:
print_error('Please set correct trainingServicePlatform!')
exit(1)
schema_dict = {
'local': LOCAL_CONFIG_SCHEMA,
'remote': REMOTE_CONFIG_SCHEMA,
'pai': PAI_CONFIG_SCHEMA,
'paiYarn': PAI_YARN_CONFIG_SCHEMA,
'kubeflow': KUBEFLOW_CONFIG_SCHEMA,
'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA
}
......@@ -255,7 +256,7 @@ def validate_machine_list(experiment_config):
def validate_pai_trial_conifg(experiment_config):
'''validate the trial config in pai platform'''
if experiment_config.get('trainingServicePlatform') == 'pai':
if experiment_config.get('trainingServicePlatform') in ['pai', 'paiYarn']:
if experiment_config.get('trial').get('shmMB') and \
experiment_config['trial']['shmMB'] > experiment_config['trial']['memoryMB']:
print_error('shmMB should be no more than memoryMB!')
......
......@@ -223,7 +223,7 @@ if __name__ == '__main__':
exit(1)
check_version(args)
try:
if NNI_PLATFORM == 'pai' and is_multi_phase():
if NNI_PLATFORM == 'paiYarn' and is_multi_phase():
fetch_parameter_file(args)
main_loop(args)
except SystemExit as se:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment