config_schema.py 25.1 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3
4

import os
chicm-ms's avatar
chicm-ms committed
5
6
7
8
import json
import netifaces
from schema import Schema, And, Optional, Regex, Or, SchemaError
from nni.package_utils import create_validator_instance, get_all_builtin_names, get_builtin_algo_meta
9
from .constants import SCHEMA_TYPE_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_PATH_ERROR
chicm-ms's avatar
chicm-ms committed
10
from .common_utils import get_yml_content, print_warning
11
12


chicm-ms's avatar
chicm-ms committed
13
def setType(key, valueType):
    '''Build a validator that checks the type of a config field.

    Parameters:
    -----------
    key: str
        Name of the config field; only used to format the error message.
    valueType: type
        Type handed to ``And`` as the check; its ``__name__`` appears in the error message.

    Returns an ``And`` validator whose error text is SCHEMA_TYPE_ERROR formatted
    with the key and the type name.
    '''
    return And(valueType, error=SCHEMA_TYPE_ERROR % (key, valueType.__name__))
16

17

18
19
20
21
def setChoice(key, *args):
    '''Build a validator that accepts only values contained in `args`.

    `key` is the config field name and is only used in the error message,
    which is SCHEMA_RANGE_ERROR formatted with the key and the tuple of choices.
    '''
    allowed_hint = str(args)
    range_message = SCHEMA_RANGE_ERROR % (key, allowed_hint)
    return And(lambda n: n in args, error=range_message)

22

23
24
25
26
27
28
29
def setNumberRange(key, keyType, start, end):
    '''Build a validator for a number of type `keyType` within [start, end].

    Two checks are chained: the type check comes first, then the inclusive
    bounds check (``start <= n <= end``). Each carries its own error message
    mentioning `key`.
    '''
    type_check = And(keyType, error=SCHEMA_TYPE_ERROR % (key, keyType.__name__))
    bounds_hint = '(%s,%s)' % (start, end)
    bounds_check = And(lambda n: start <= n <= end, error=SCHEMA_RANGE_ERROR % (key, bounds_hint))
    return And(type_check, bounds_check)

30

31
32
33
def setPathCheck(key):
    '''Build a validator that requires the value of `key` to be an existing path.

    Existence is tested with ``os.path.exists``; `key` is only used to format
    the SCHEMA_PATH_ERROR message.
    '''
    missing_path_message = SCHEMA_PATH_ERROR % key
    return And(os.path.exists, error=missing_path_message)
34

35

chicm-ms's avatar
chicm-ms committed
36
37
38
39
40
41
class AlgoSchema:
    """
    This class is the schema of 'tuner', 'assessor' and 'advisor' sections of experiment configuration file.
    For example:
    AlgoSchema('tuner') creates the schema of tuner section.
    """

    def __init__(self, algo_type):
        """
        Parameters:
        -----------
        algo_type: str
            One of ['tuner', 'assessor', 'advisor'].
            'tuner': This AlgoSchema class create the schema of tuner section.
            'assessor': This AlgoSchema class create the schema of assessor section.
            'advisor': This AlgoSchema class create the schema of advisor section.
        """
        assert algo_type in ['tuner', 'assessor', 'advisor']
        self.algo_type = algo_type
        # Keys accepted in every algorithm section, builtin or customized.
        self.algo_schema = {
            Optional('codeDir'): setPathCheck('codeDir'),
            Optional('classFileName'): setType('classFileName', str),
            Optional('className'): setType('className', str),
            Optional('classArgs'): dict,
            Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
            # Either a single int, or a comma separated string in which every item parses as int.
            Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
        }
        # Name of the key that selects a builtin algorithm, per section type.
        self.builtin_keys = {
            'tuner': 'builtinTunerName',
            'assessor': 'builtinAssessorName',
            'advisor': 'builtinAdvisorName'
        }
        self.builtin_name_schema = {}
        for k, n in self.builtin_keys.items():
            self.builtin_name_schema[k] = {Optional(n): setChoice(n, *get_all_builtin_names(k+'s'))}

        # Keys that together describe a customized (user supplied) algorithm.
        self.customized_keys = {'codeDir', 'classFileName', 'className'}

    def validate_class_args(self, class_args, algo_type, builtin_name):
        """
        Check `class_args` of a builtin algorithm against its registered metadata and validator.

        Nothing is checked when no builtin name or no class args are given.
        Raises SchemaError when the algorithm does not accept class args, or
        when its validator rejects them.
        """
        if not builtin_name or not class_args:
            return
        meta = get_builtin_algo_meta(algo_type+'s', builtin_name)
        if meta and 'accept_class_args' in meta and meta['accept_class_args'] == False:
            raise SchemaError('classArgs is not allowed.')

        validator = create_validator_instance(algo_type+'s', builtin_name)
        if validator:
            try:
                validator.validate_class_args(**class_args)
            except Exception as e:
                # Report any validator failure as a schema error, keeping the original cause.
                raise SchemaError(str(e)) from e

    def missing_customized_keys(self, data):
        """Return the set of customized-algorithm keys that are absent from `data`."""
        return self.customized_keys - set(data.keys())

    def validate_extras(self, data, algo_type):
        """
        Cross-field checks that the plain key/type schema cannot express.

        Raises SchemaError when builtin and customized keys are mixed, when
        neither a builtin name nor a complete customized spec is given, when
        the customized class file does not exist, or when classArgs are invalid.
        """
        builtin_key = self.builtin_keys[algo_type]
        specified_customized_keys = set(data.keys()) & self.customized_keys
        if (builtin_key in data) and specified_customized_keys:
            raise SchemaError('{} and {} cannot be specified at the same time.'.format(
                builtin_key, specified_customized_keys
            ))

        if self.missing_customized_keys(data) and builtin_key not in data:
            raise SchemaError('Either customized {} ({}) or builtin {} ({}) must be set.'.format(
                algo_type, self.customized_keys, algo_type, builtin_key))

        if not self.missing_customized_keys(data):
            class_file_name = os.path.join(data['codeDir'], data['classFileName'])
            if not os.path.isfile(class_file_name):
                raise SchemaError('classFileName {} not found.'.format(class_file_name))

        builtin_name = data.get(builtin_key)
        class_args = data.get('classArgs')
        self.validate_class_args(class_args, algo_type, builtin_name)

    def validate(self, data):
        """
        Validate one algorithm section and return it.

        Instances of this class are used as values inside other schemas; the
        `schema` package takes the return value of a nested validator as the
        validated data, so `data` is returned rather than None.
        """
        self.algo_schema.update(self.builtin_name_schema[self.algo_type])
        Schema(self.algo_schema).validate(data)
        self.validate_extras(data, self.algo_type)
        return data

116

117
# Top-level fields shared by the schema of every training service platform.
common_schema = {
    'authorName': setType('authorName', str),
    'experimentName': setType('experimentName', str),
    Optional('description'): setType('description', str),
    'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
    # The unit must be exactly one of s/m/h/d. Inside a character class '|' is a
    # literal character, so it must not be used as a separator there.
    Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[smhd]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
    Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
    'trainingServicePlatform': setChoice(
        'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'),
    Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
    Optional('multiPhase'): setType('multiPhase', bool),
    Optional('multiThread'): setType('multiThread', bool),
    Optional('nniManagerIp'): setType('nniManagerIp', str),
    Optional('logDir'): And(os.path.isdir, error=SCHEMA_PATH_ERROR % 'logDir'),
    Optional('debug'): setType('debug', bool),
    Optional('versionCheck'): setType('versionCheck', bool),
    Optional('logLevel'): setChoice('logLevel', 'trace', 'debug', 'info', 'warning', 'error', 'fatal'),
    Optional('logCollection'): setChoice('logCollection', 'http', 'none'),
    'useAnnotation': setType('useAnnotation', bool),
    # Algorithm sections are validated by AlgoSchema instances.
    Optional('tuner'): AlgoSchema('tuner'),
    Optional('advisor'): AlgoSchema('advisor'),
    Optional('assessor'): AlgoSchema('assessor'),
    Optional('localConfig'): {
        # Either a single int, or a comma separated string in which every item parses as int.
        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
        Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
        Optional('useActiveGpu'): setType('useActiveGpu', bool)
    }
}
145
146

# 'trial' section used by the 'local' and 'remote' platforms.
common_trial_schema = {
    'trial': {
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode')
    }
}

155
# 'trial' section of the 'paiYarn' platform.
pai_yarn_trial_schema = {
    'trial': {
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
        'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
        'memoryMB': setType('memoryMB', int),
        'image': setType('image', str),
        Optional('authFile'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'authFile'),
        Optional('shmMB'): setType('shmMB', int),
        # The dots of the IPv4 address are escaped so they only match a literal '.'.
        Optional('dataDir'): And(Regex(r'hdfs://(([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
                                 error='ERROR: dataDir format error, dataDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
        Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
                                   error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
        Optional('virtualCluster'): setType('virtualCluster', str),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
        Optional('portList'): [{
            "label": setType('label', str),
            "beginAt": setType('beginAt', int),
            "portNumber": setType('portNumber', int)
        }]
    }
}

179
180
181
182
183
184
185
186
187
188
189
190
191
192
# 'paiYarnConfig' section: two alternative shapes that differ only in the
# credential field ('passWord' or 'token'); 'userName' and 'host' are in both.
pai_yarn_config_schema = {
    'paiYarnConfig': Or(*[
        {
            'userName': setType('userName', str),
            credential_field: setType(credential_field, str),
            'host': setType('host', str)
        }
        for credential_field in ('passWord', 'token')
    ])
}


# 'trial' section of the 'pai' platform. Most resource fields are optional here;
# NNIConfigSchema.validate_pai_config_path checks that they are provided either
# in this section, in 'paiConfig', or through 'paiConfigPath'.
pai_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
        'containerNFSMountPath': setType('containerNFSMountPath', str),
        Optional('command'): setType('command', str),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
        Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
        Optional('memoryMB'): setType('memoryMB', int),
        Optional('image'): setType('image', str),
        Optional('virtualCluster'): setType('virtualCluster', str),
        Optional('paiStorageConfigName'): setType('paiStorageConfigName', str),
        Optional('paiConfigPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'paiConfigPath')
    }
}

208
# 'paiConfig' section of the 'pai' platform.
pai_config_schema = {
    'paiConfig': {
        'userName': setType('userName', str),
        # Credential key: 'passWord' or 'token', declared with only_one=True.
        Or('passWord', 'token', only_one=True): str,
        'host': setType('host', str),
        Optional('reuse'): setType('reuse', bool),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
        Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
        Optional('memoryMB'): setType('memoryMB', int),
        Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
        Optional('useActiveGpu'): setType('useActiveGpu', bool),
    }
}

George Cheng's avatar
George Cheng committed
222
# 'trial' section of the 'dlts' platform; every field is required.
dlts_trial_schema = {
    'trial': {
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
        'image': setType('image', str),
    }
}

# 'dltsConfig' section: 'dashboard' is required, the remaining string fields are optional.
dlts_config_schema = {
    'dltsConfig': {
        'dashboard': setType('dashboard', str),
        **{
            Optional(field_name): setType(field_name, str)
            for field_name in ('cluster', 'team', 'email', 'password')
        },
    }
}

SparkSnail's avatar
SparkSnail committed
243
# 'trial' section of the 'aml' platform; every field is required.
aml_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        'command': setType('command', str),
        'image': setType('image', str),
        'computeTarget': setType('computeTarget', str)
    }
}

# 'amlConfig' section: three required string fields.
aml_config_schema = {
    'amlConfig': {
        field_name: setType(field_name, str)
        for field_name in ('subscriptionId', 'resourceGroup', 'workspaceName')
    }
}

260
# 'trial' section of the 'kubeflow' platform. The optional 'ps', 'master' and
# 'worker' roles share one layout, so each role gets its own copy of the same
# sub-schema. Which roles are allowed per operator is checked separately in
# NNIConfigSchema.validate_kubeflow_operators.
kubeflow_trial_schema = {
    'trial': {
        'codeDir':  setPathCheck('codeDir'),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
        **{
            Optional(role): {
                'replicas': setType('replicas', int),
                'command': setType('command', str),
                'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
                'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
                'memoryMB': setType('memoryMB', int),
                'image': setType('image', str),
                Optional('privateRegistryAuthPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'privateRegistryAuthPath')
            }
            for role in ('ps', 'master', 'worker')
        }
    }
}

# 'kubeflowConfig' section: either an NFS backed layout or an Azure storage
# backed layout (keyVault + azureStorage).
# The name patterns are anchored with ^...$ because Regex searches for the
# pattern anywhere in the value; unanchored, any string containing a single
# allowed character would pass.
kubeflow_config_schema = {
    'kubeflowConfig': Or({
        'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
        'apiVersion': setType('apiVersion', str),
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        'nfs': {
            'server': setType('server', str),
            'path': setType('path', str)
        }
    }, {
        'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
        'apiVersion': setType('apiVersion', str),
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        'keyVault': {
            'vaultName': And(Regex('^([0-9]|[a-z]|[A-Z]|-){1,127}$'),
                             error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
            'name': And(Regex('^([0-9]|[a-z]|[A-Z]|-){1,127}$'),
                        error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
        },
        'azureStorage': {
            'accountName': And(Regex('^([0-9]|[a-z]|[A-Z]|-){3,31}$'),
                               error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
            'azureShare': And(Regex('^([0-9]|[a-z]|[A-Z]|-){3,63}$'),
                              error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
        },
        Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999)
    })
}

323
# 'trial' section of the 'frameworkcontroller' platform: a code directory plus
# a list of task roles, each with its own resources and completion policy.
frameworkcontroller_trial_schema = {
    'trial': {
        'codeDir':  setPathCheck('codeDir'),
        'taskRoles': [{
            'name': setType('name', str),
            'taskNum': setType('taskNum', int),
            'frameworkAttemptCompletionPolicy': {
                'minFailedTaskCount': setType('minFailedTaskCount', int),
                'minSucceededTaskCount': setType('minSucceededTaskCount', int),
            },
            'command': setType('command', str),
            'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
            'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
            'memoryMB': setType('memoryMB', int),
            'image': setType('image', str),
            Optional('privateRegistryAuthPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'privateRegistryAuthPath')
        }]
    }
}

# 'frameworkcontrollerConfig' section: either an NFS backed layout or an Azure
# storage backed layout (keyVault + azureStorage).
# The name patterns are anchored with ^...$ because Regex searches for the
# pattern anywhere in the value; unanchored, any string containing a single
# allowed character would pass.
frameworkcontroller_config_schema = {
    'frameworkcontrollerConfig': Or({
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        Optional('serviceAccountName'): setType('serviceAccountName', str),
        'nfs': {
            'server': setType('server', str),
            'path': setType('path', str)
        }
    }, {
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        Optional('serviceAccountName'): setType('serviceAccountName', str),
        'keyVault': {
            'vaultName': And(Regex('^([0-9]|[a-z]|[A-Z]|-){1,127}$'),
                             error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
            'name': And(Regex('^([0-9]|[a-z]|[A-Z]|-){1,127}$'),
                        error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
        },
        'azureStorage': {
            'accountName': And(Regex('^([0-9]|[a-z]|[A-Z]|-){3,31}$'),
                               error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
            'azureShare': And(Regex('^([0-9]|[a-z]|[A-Z]|-){3,63}$'),
                              error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
        },
        Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999)
    })
}

370
# 'machineList' section of the 'remote' platform. Each machine matches one of
# two shapes: SSH key authentication ('sshKeyPath', optional 'passphrase') or
# password authentication ('passwd').
machine_list_schema = {
    'machineList': [Or(
        {
            'ip': setType('ip', str),
            Optional('port'): setNumberRange('port', int, 1, 65535),
            'username': setType('username', str),
            'sshKeyPath': setPathCheck('sshKeyPath'),
            Optional('passphrase'): setType('passphrase', str),
            # Either a single int, or a comma separated string in which every item parses as int.
            Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
            Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
            Optional('useActiveGpu'): setType('useActiveGpu', bool)
        },
        {
            'ip': setType('ip', str),
            Optional('port'): setNumberRange('port', int, 1, 65535),
            'username': setType('username', str),
            'passwd': setType('passwd', str),
            Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
            Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
            Optional('useActiveGpu'): setType('useActiveGpu', bool)
        })]
}
392

chicm-ms's avatar
chicm-ms committed
393
394
395
396
397
398
399
# Complete schema per training service platform: the common fields merged with
# the platform's 'trial' section and, where one exists, its config section.
# Keys are the values accepted for 'trainingServicePlatform'.
training_service_schema_dict = {
    'local': Schema({**common_schema, **common_trial_schema}),
    'remote': Schema({**common_schema, **common_trial_schema, **machine_list_schema}),
    'pai': Schema({**common_schema, **pai_trial_schema, **pai_config_schema}),
    'paiYarn': Schema({**common_schema, **pai_yarn_trial_schema, **pai_yarn_config_schema}),
    'kubeflow': Schema({**common_schema, **kubeflow_trial_schema, **kubeflow_config_schema}),
    'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}),
    'aml': Schema({**common_schema, **aml_trial_schema, **aml_config_schema}),
    'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}),
}

404

chicm-ms's avatar
chicm-ms committed
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
class NNIConfigSchema:
    '''
    Validator for a whole experiment configuration.

    `validate` first applies the schema of the selected training service
    platform, then runs the cross-field checks in `validate_extras`.
    Every failed check is reported as a SchemaError.
    '''

    def validate(self, data):
        '''Validate the experiment config `data`; raise SchemaError when it is invalid.'''
        # .get() lets a missing platform be reported by the choice validator as
        # a SchemaError instead of surfacing as a KeyError.
        train_service = data.get('trainingServicePlatform')
        Schema(common_schema['trainingServicePlatform']).validate(train_service)
        train_service_schema = training_service_schema_dict[train_service]
        train_service_schema.validate(data)
        self.validate_extras(data)

    def validate_extras(self, experiment_config):
        '''Run the checks that depend on more than one field of the config.'''
        self.validate_tuner_adivosr_assessor(experiment_config)
        self.validate_pai_trial_conifg(experiment_config)
        self.validate_kubeflow_operators(experiment_config)
        self.validate_eth0_device(experiment_config)

    def validate_tuner_adivosr_assessor(self, experiment_config):
        '''
        Validate the combination of algorithm sections: an advisor excludes
        tuner and assessor; without an advisor a tuner is mandatory.
        '''
        if experiment_config.get('advisor'):
            if experiment_config.get('assessor') or experiment_config.get('tuner'):
                raise SchemaError('advisor could not be set with assessor or tuner simultaneously!')
            self.validate_annotation_content(experiment_config, 'advisor', 'builtinAdvisorName')
        else:
            if not experiment_config.get('tuner'):
                raise SchemaError('Please provide tuner spec!')
            self.validate_annotation_content(experiment_config, 'tuner', 'builtinTunerName')

    def validate_search_space_content(self, experiment_config):
        '''Validate searchspace content,
        if the searchspace file is not json format or its values does not contain _type and _value which must be specified,
        it will not be a valid searchspace file'''
        try:
            # The context manager closes the file even when parsing fails.
            with open(experiment_config.get('searchSpacePath'), 'r') as search_space_file:
                search_space_content = json.load(search_space_file)
            for value in search_space_content.values():
                if not value.get('_type') or not value.get('_value'):
                    raise SchemaError('please use _type and _value to specify searchspace!')
        except SchemaError:
            # Keep the specific message; do not re-label it as a json format problem.
            raise
        except Exception as e:
            raise SchemaError('searchspace file is not a valid json format! ' + str(e)) from e

    def validate_kubeflow_operators(self, experiment_config):
        '''Validate whether the kubeflow operators are valid'''
        kubeflow_config = experiment_config.get('kubeflowConfig')
        if not kubeflow_config:
            return

        # Roles allowed in the trial section depend on the operator.
        operator = kubeflow_config.get('operator')
        if operator == 'tf-operator':
            if experiment_config.get('trial').get('master') is not None:
                raise SchemaError('kubeflow with tf-operator can not set master')
            if experiment_config.get('trial').get('worker') is None:
                raise SchemaError('kubeflow with tf-operator must set worker')
        elif operator == 'pytorch-operator':
            if experiment_config.get('trial').get('ps') is not None:
                raise SchemaError('kubeflow with pytorch-operator can not set ps')
            if experiment_config.get('trial').get('master') is None:
                raise SchemaError('kubeflow with pytorch-operator must set master')

        # The declared storage type must come with its configuration block.
        storage = kubeflow_config.get('storage')
        if storage == 'nfs':
            if kubeflow_config.get('nfs') is None:
                raise SchemaError('please set nfs configuration!')
        elif storage == 'azureStorage':
            if kubeflow_config.get('azureStorage') is None:
                raise SchemaError('please set azureStorage configuration!')
        elif storage is None:
            if kubeflow_config.get('azureStorage'):
                raise SchemaError('please set storage type!')

    def validate_annotation_content(self, experiment_config, spec_key, builtin_name):
        '''
        Valid whether useAnnotation and searchSpacePath is coexist
        spec_key: 'advisor' or 'tuner'
        builtin_name: 'builtinAdvisorName' or 'builtinTunerName'
        '''
        if experiment_config.get('useAnnotation'):
            if experiment_config.get('searchSpacePath'):
                raise SchemaError('If you set useAnnotation=true, please leave searchSpacePath empty')
            return

        # Without annotation a builtin algorithm needs a search space file,
        # except NetworkMorphism, which is exempt from this check.
        selected_builtin = experiment_config[spec_key].get(builtin_name)
        if selected_builtin == 'NetworkMorphism':
            return
        if selected_builtin:
            if experiment_config.get('searchSpacePath') is None:
                raise SchemaError('Please set searchSpacePath!')
            self.validate_search_space_content(experiment_config)

    def validate_pai_config_path(self, experiment_config):
        '''validate paiConfigPath field'''
        if experiment_config.get('trainingServicePlatform') == 'pai':
            if experiment_config.get('trial', {}).get('paiConfigPath'):
                # validate commands
                pai_config = get_yml_content(experiment_config['trial']['paiConfigPath'])
                taskRoles_dict = pai_config.get('taskRoles')
                if not taskRoles_dict:
                    raise SchemaError('Please set taskRoles in paiConfigPath config file!')
            else:
                # Without a pai config file these fields must be in the trial section ...
                pai_trial_fields_required_list = ['image', 'paiStorageConfigName', 'command']
                for trial_field in pai_trial_fields_required_list:
                    if experiment_config['trial'].get(trial_field) is None:
                        raise SchemaError('Please set {0} in trial configuration,\
                                    or set additional pai configuration file path in paiConfigPath!'.format(trial_field))
                # ... and the resource fields in either the trial or the paiConfig section.
                pai_resource_fields_required_list = ['gpuNum', 'cpuNum', 'memoryMB']
                for required_field in pai_resource_fields_required_list:
                    if experiment_config['trial'].get(required_field) is None and \
                            experiment_config['paiConfig'].get(required_field) is None:
                        raise SchemaError('Please set {0} in trial or paiConfig configuration,\
                                    or set additional pai configuration file path in paiConfigPath!'.format(required_field))

    def validate_pai_trial_conifg(self, experiment_config):
        '''validate the trial config in pai platform'''
        if experiment_config.get('trainingServicePlatform') in ['pai', 'paiYarn']:
            trial_config = experiment_config.get('trial')
            shm_mb = trial_config.get('shmMB')
            memory_mb = trial_config.get('memoryMB')
            # Only compare when both values are present, so a missing memoryMB
            # does not surface as a KeyError.
            if shm_mb and memory_mb is not None and shm_mb > memory_mb:
                raise SchemaError('shmMB should be no more than memoryMB!')
            # backward compatibility
            warning_information = '{0} is not supported in NNI anymore, please remove the field in config file!\
            please refer https://github.com/microsoft/nni/blob/master/docs/en_US/TrainingService/PaiMode.md#run-an-experiment\
            for the practices of how to get data and output model in trial code'
            if trial_config.get('dataDir'):
                print_warning(warning_information.format('dataDir'))
            if trial_config.get('outputDir'):
                print_warning(warning_information.format('outputDir'))
            self.validate_pai_config_path(experiment_config)

    def validate_eth0_device(self, experiment_config):
        '''validate whether the machine has eth0 device'''
        # The interface lookup comes last so it only runs for non-local
        # platforms that did not set nniManagerIp.
        if experiment_config.get('trainingServicePlatform') not in ['local'] \
                and not experiment_config.get('nniManagerIp') \
                and 'eth0' not in netifaces.interfaces():
            raise SchemaError('This machine does not contain eth0 network device, please set nniManagerIp in config file!')