Unverified Commit 8a60d624 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Support paiStorageConfigName (#2536)

parent 52f71f54
...@@ -41,8 +41,8 @@ def update_training_service_config(args): ...@@ -41,8 +41,8 @@ def update_training_service_config(args):
config[args.ts]['trial']['nniManagerNFSMountPath'] = args.nni_manager_nfs_mount_path config[args.ts]['trial']['nniManagerNFSMountPath'] = args.nni_manager_nfs_mount_path
if args.container_nfs_mount_path is not None: if args.container_nfs_mount_path is not None:
config[args.ts]['trial']['containerNFSMountPath'] = args.container_nfs_mount_path config[args.ts]['trial']['containerNFSMountPath'] = args.container_nfs_mount_path
if args.pai_storage_plugin is not None: if args.pai_storage_config_name is not None:
config[args.ts]['trial']['paiStoragePlugin'] = args.pai_storage_plugin config[args.ts]['trial']['paiStorageConfigName'] = args.pai_storage_config_name
if args.vc is not None: if args.vc is not None:
config[args.ts]['trial']['virtualCluster'] = args.vc config[args.ts]['trial']['virtualCluster'] = args.vc
elif args.ts == 'kubeflow': elif args.ts == 'kubeflow':
...@@ -102,6 +102,7 @@ if __name__ == '__main__': ...@@ -102,6 +102,7 @@ if __name__ == '__main__':
parser.add_argument("--vc", type=str) parser.add_argument("--vc", type=str)
parser.add_argument("--pai_token", type=str) parser.add_argument("--pai_token", type=str)
parser.add_argument("--pai_storage_plugin", type=str) parser.add_argument("--pai_storage_plugin", type=str)
parser.add_argument("--pai_storage_config_name", type=str)
parser.add_argument("--nni_manager_nfs_mount_path", type=str) parser.add_argument("--nni_manager_nfs_mount_path", type=str)
parser.add_argument("--container_nfs_mount_path", type=str) parser.add_argument("--container_nfs_mount_path", type=str)
# args for kubeflow and frameworkController # args for kubeflow and frameworkController
......
...@@ -292,7 +292,7 @@ pai_trial_schema = { ...@@ -292,7 +292,7 @@ pai_trial_schema = {
Optional('memoryMB'): setType('memoryMB', int), Optional('memoryMB'): setType('memoryMB', int),
Optional('image'): setType('image', str), Optional('image'): setType('image', str),
Optional('virtualCluster'): setType('virtualCluster', str), Optional('virtualCluster'): setType('virtualCluster', str),
Optional('paiStoragePlugin'): setType('paiStoragePlugin', str), Optional('paiStorageConfigName'): setType('paiStorageConfigName', str),
Optional('paiConfigPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'paiConfigPath') Optional('paiConfigPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'paiConfigPath')
} }
} }
......
...@@ -273,7 +273,7 @@ def validate_pai_config_path(experiment_config): ...@@ -273,7 +273,7 @@ def validate_pai_config_path(experiment_config):
print_error('Please set taskRoles in paiConfigPath config file!') print_error('Please set taskRoles in paiConfigPath config file!')
exit(1) exit(1)
else: else:
pai_trial_fields_required_list = ['image', 'gpuNum', 'cpuNum', 'memoryMB', 'paiStoragePlugin', 'command'] pai_trial_fields_required_list = ['image', 'gpuNum', 'cpuNum', 'memoryMB', 'paiStorageConfigName', 'command']
for trial_field in pai_trial_fields_required_list: for trial_field in pai_trial_fields_required_list:
if experiment_config['trial'].get(trial_field) is None: if experiment_config['trial'].get(trial_field) is None:
print_error('Please set {0} in trial configuration,\ print_error('Please set {0} in trial configuration,\
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment