config_schema.py 25.3 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3
4

import os
chicm-ms's avatar
chicm-ms committed
5
6
7
8
import json
import netifaces
from schema import Schema, And, Optional, Regex, Or, SchemaError
from nni.package_utils import create_validator_instance, get_all_builtin_names, get_builtin_algo_meta
9
from .constants import SCHEMA_TYPE_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_PATH_ERROR
chicm-ms's avatar
chicm-ms committed
10
from .common_utils import get_yml_content, print_warning
11
12


chicm-ms's avatar
chicm-ms committed
13
def setType(key, valueType):
    '''Build a validator that checks a value is an instance of ``valueType``.

    ``key`` is only used to render the error message: SCHEMA_TYPE_ERROR is
    filled with the key and the name of the expected type.
    '''
    type_error = SCHEMA_TYPE_ERROR % (key, valueType.__name__)
    return And(valueType, error=type_error)
16

17

18
19
20
21
def setChoice(key, *args):
    '''Build a validator that accepts only the values listed in ``args``.

    ``key`` is only used to render the error message: SCHEMA_RANGE_ERROR is
    filled with the key and the string form of the allowed values.
    '''
    range_error = SCHEMA_RANGE_ERROR % (key, str(args))
    return And(lambda n: n in args, error=range_error)

22

23
24
25
26
27
28
29
def setNumberRange(key, keyType, start, end):
    '''Build a validator for a number of type ``keyType`` within [start, end].

    The type is checked first, then the inclusive bounds; each step carries
    its own error message mentioning ``key``.
    '''
    type_check = And(keyType, error=SCHEMA_TYPE_ERROR % (key, keyType.__name__))
    bounds_text = '(%s,%s)' % (start, end)
    bounds_check = And(lambda n: start <= n <= end, error=SCHEMA_RANGE_ERROR % (key, bounds_text))
    return And(type_check, bounds_check)

30

31
32
33
def setPathCheck(key):
    '''Build a validator that requires the value to be an existing path.

    ``key`` is only used to render the SCHEMA_PATH_ERROR message.
    '''
    path_error = SCHEMA_PATH_ERROR % key
    return And(os.path.exists, error=path_error)
34

35

chicm-ms's avatar
chicm-ms committed
36
37
38
39
40
41
class AlgoSchema:
    """
    This class is the schema of 'tuner', 'assessor' and 'advisor' sections of experiment configuration file.
    For example:
    AlgoSchema('tuner') creates the schema of tuner section.

    Instances expose a ``validate`` method, so they can be used directly as
    values inside a ``schema`` dictionary.
    """

    def __init__(self, algo_type):
        """
        Parameters:
        -----------
        algo_type: str
            One of ['tuner', 'assessor', 'advisor'].
            'tuner': This AlgoSchema class create the schema of tuner section.
            'assessor': This AlgoSchema class create the schema of assessor section.
            'advisor': This AlgoSchema class create the schema of advisor section.
        """
        assert algo_type in ['tuner', 'assessor', 'advisor']
        self.algo_type = algo_type
        # Keys shared by all three algorithm kinds.
        self.algo_schema = {
            Optional('codeDir'): setPathCheck('codeDir'),
            Optional('classFileName'): setType('classFileName', str),
            Optional('className'): setType('className', str),
            Optional('classArgs'): dict,
            Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
            # Either a single int or a comma separated string of ints.
            Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
        }
        # Name of the "builtin algorithm" key for each algorithm kind.
        self.builtin_keys = {
            'tuner': 'builtinTunerName',
            'assessor': 'builtinAssessorName',
            'advisor': 'builtinAdvisorName'
        }
        self.builtin_name_schema = {}
        for k, n in self.builtin_keys.items():
            self.builtin_name_schema[k] = {Optional(n): setChoice(n, *get_all_builtin_names(k+'s'))}

        # Keys that together describe a customized (user supplied) algorithm.
        self.customized_keys = set(['codeDir', 'classFileName', 'className'])

    def validate_class_args(self, class_args, algo_type, builtin_name):
        """
        Validate ``classArgs`` of a builtin algorithm.

        Nothing is checked when either ``builtin_name`` or ``class_args`` is empty.

        Raises:
        -------
        SchemaError
            If the builtin algorithm does not accept class args, or if its
            validator rejects them.
        """
        if not builtin_name or not class_args:
            return
        meta = get_builtin_algo_meta(algo_type+'s', builtin_name)
        if meta and 'accept_class_args' in meta and meta['accept_class_args'] is False:
            raise SchemaError('classArgs is not allowed.')

        validator = create_validator_instance(algo_type+'s', builtin_name)
        if validator:
            try:
                validator.validate_class_args(**class_args)
            except Exception as e:
                # Report any validator failure as a schema error, keeping the cause.
                raise SchemaError(str(e)) from e

    def missing_customized_keys(self, data):
        """Return the set of customized-algorithm keys that are absent from ``data``."""
        return self.customized_keys - set(data.keys())

    def validate_extras(self, data, algo_type):
        """
        Run the cross-key checks of an algorithm section.

        Raises:
        -------
        SchemaError
            If builtin and customized keys are mixed, if neither a builtin name
            nor the full set of customized keys is given, if the customized
            class file does not exist, or if ``classArgs`` is rejected.
        """
        builtin_key = self.builtin_keys[algo_type]
        present_customized = set(data.keys()) & self.customized_keys
        if (builtin_key in data) and present_customized:
            raise SchemaError('{} and {} cannot be specified at the same time.'.format(
                builtin_key, present_customized
            ))

        missing = self.missing_customized_keys(data)
        if missing and builtin_key not in data:
            raise SchemaError('Either customized {} ({}) or builtin {} ({}) must be set.'.format(
                algo_type, self.customized_keys, algo_type, builtin_key))

        if not missing:
            # A fully specified customized algorithm must point at a real file.
            class_file_name = os.path.join(data['codeDir'], data['classFileName'])
            if not os.path.isfile(class_file_name):
                raise SchemaError('classFileName {} not found.'.format(class_file_name))

        builtin_name = data.get(builtin_key)
        class_args = data.get('classArgs')
        self.validate_class_args(class_args, algo_type, builtin_name)

    def validate(self, data):
        """
        Validate one algorithm section and return the validated data.

        The ``schema`` library uses the value returned by a validator's
        ``validate`` as the validated data, so the result is returned instead
        of an implicit None.

        Raises:
        -------
        SchemaError
            If the section violates the schema or one of the extra checks.
        """
        # Idempotent: adds the builtin-name key matching this algorithm kind.
        self.algo_schema.update(self.builtin_name_schema[self.algo_type])
        validated = Schema(self.algo_schema).validate(data)
        self.validate_extras(data, self.algo_type)
        return validated

116

117
# Keys accepted at the top level of every experiment configuration,
# independent of the training service platform.
common_schema = {
    'authorName': setType('authorName', str),
    'experimentName': setType('experimentName', str),
    Optional('description'): setType('description', str),
    'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
    # A positive integer followed by exactly one unit letter: s, m, h or d.
    Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[smhd]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
    Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
    'trainingServicePlatform': setChoice(
        'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'),
    Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
    Optional('multiPhase'): setType('multiPhase', bool),
    Optional('multiThread'): setType('multiThread', bool),
    Optional('nniManagerIp'): setType('nniManagerIp', str),
    Optional('logDir'): And(os.path.isdir, error=SCHEMA_PATH_ERROR % 'logDir'),
    Optional('debug'): setType('debug', bool),
    Optional('versionCheck'): setType('versionCheck', bool),
    Optional('logLevel'): setChoice('logLevel', 'trace', 'debug', 'info', 'warning', 'error', 'fatal'),
    Optional('logCollection'): setChoice('logCollection', 'http', 'none'),
    'useAnnotation': setType('useAnnotation', bool),
    Optional('tuner'): AlgoSchema('tuner'),
    Optional('advisor'): AlgoSchema('advisor'),
    Optional('assessor'): AlgoSchema('assessor'),
    Optional('localConfig'): {
        # Either a single int or a comma separated string of ints.
        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
        Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
        Optional('useActiveGpu'): setType('useActiveGpu', bool)
    }
}
145
146

# 'trial' section used by the 'local' and 'remote' platforms.
common_trial_schema = {
    'trial': {
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
    },
}

155
# 'trial' section for the 'paiYarn' platform.
pai_yarn_trial_schema = {
    'trial': {
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
        'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
        'memoryMB': setType('memoryMB', int),
        'image': setType('image', str),
        Optional('authFile'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'authFile'),
        Optional('shmMB'): setType('shmMB', int),
        # hdfs://<dotted IPv4>[:port][/path]; the dots between octets are
        # escaped so that they only match a literal '.'.
        Optional('dataDir'): And(Regex(r'hdfs://(([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
                                 error='ERROR: dataDir format error, dataDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
        Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
                                   error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
        Optional('virtualCluster'): setType('virtualCluster', str),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
        Optional('portList'): [{
            "label": setType('label', str),
            "beginAt": setType('beginAt', int),
            "portNumber": setType('portNumber', int)
        }]
    }
}

179
180
181
182
183
184
185
186
187
188
189
190
191
192
# 'paiYarnConfig' section: userName and host plus either a passWord or a token.
pai_yarn_config_schema = {
    'paiYarnConfig': Or({
        # Alternative 1: password based credentials.
        'userName': setType('userName', str),
        'passWord': setType('passWord', str),
        'host': setType('host', str)
    }, {
        # Alternative 2: token based credentials.
        'userName': setType('userName', str),
        'token': setType('token', str),
        'host': setType('host', str)
    })
}


# 'trial' section for the 'pai' platform. Most resource fields are optional
# here; NNIConfigSchema.validate_pai_config_path checks them afterwards.
pai_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
        'containerNFSMountPath': setType('containerNFSMountPath', str),
        Optional('command'): setType('command', str),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
        Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
        Optional('memoryMB'): setType('memoryMB', int),
        Optional('image'): setType('image', str),
        Optional('virtualCluster'): setType('virtualCluster', str),
        Optional('paiStorageConfigName'): setType('paiStorageConfigName', str),
        Optional('paiConfigPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'paiConfigPath'),
    },
}

208
# 'paiConfig' section for the 'pai' platform.
pai_config_schema = {
    'paiConfig': {
        'userName': setType('userName', str),
        # Exactly one of 'passWord' / 'token' must be given.
        Or('passWord', 'token', only_one=True): str,
        'host': setType('host', str),
        Optional('reuse'): setType('reuse', bool),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
        Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
        Optional('memoryMB'): setType('memoryMB', int),
        Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
        Optional('useActiveGpu'): setType('useActiveGpu', bool),
    },
}

George Cheng's avatar
George Cheng committed
222
# 'trial' section for the 'dlts' platform.
dlts_trial_schema = {
    'trial': {
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
        'image': setType('image', str),
    },
}

# 'dltsConfig' section for the 'dlts' platform; only 'dashboard' is required.
dlts_config_schema = {
    'dltsConfig': {
        'dashboard': setType('dashboard', str),

        Optional('cluster'): setType('cluster', str),
        Optional('team'): setType('team', str),

        Optional('email'): setType('email', str),
        Optional('password'): setType('password', str),
    }
}

SparkSnail's avatar
SparkSnail committed
243
# 'trial' section for the 'aml' platform.
aml_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        'command': setType('command', str),
        'image': setType('image', str),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
    },
}

# 'amlConfig' section for the 'aml' platform.
aml_config_schema = {
    'amlConfig': {
        'subscriptionId': setType('subscriptionId', str),
        'resourceGroup': setType('resourceGroup', str),
        'workspaceName': setType('workspaceName', str),
        'computeTarget': setType('computeTarget', str),
        Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
        Optional('useActiveGpu'): setType('useActiveGpu', bool),
    },
}

263
# 'trial' section for the 'kubeflow' platform. The 'ps', 'master' and 'worker'
# roles share the same shape; which of them may be set depends on the operator
# and is checked in NNIConfigSchema.validate_kubeflow_operators.
kubeflow_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
        Optional('ps'): {
            'replicas': setType('replicas', int),
            'command': setType('command', str),
            'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
            'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
            'memoryMB': setType('memoryMB', int),
            'image': setType('image', str),
            Optional('privateRegistryAuthPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'privateRegistryAuthPath'),
        },
        Optional('master'): {
            'replicas': setType('replicas', int),
            'command': setType('command', str),
            'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
            'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
            'memoryMB': setType('memoryMB', int),
            'image': setType('image', str),
            Optional('privateRegistryAuthPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'privateRegistryAuthPath'),
        },
        Optional('worker'): {
            'replicas': setType('replicas', int),
            'command': setType('command', str),
            'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
            'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
            'memoryMB': setType('memoryMB', int),
            'image': setType('image', str),
            Optional('privateRegistryAuthPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'privateRegistryAuthPath'),
        },
    },
}

# 'kubeflowConfig' section: either an NFS backed or an Azure storage backed
# configuration.
kubeflow_config_schema = {
    'kubeflowConfig': Or(
        {
            # Alternative 1: NFS storage.
            'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
            'apiVersion': setType('apiVersion', str),
            Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
            'nfs': {
                'server': setType('server', str),
                'path': setType('path', str),
            },
        },
        {
            # Alternative 2: Azure storage with credentials kept in a key vault.
            'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
            'apiVersion': setType('apiVersion', str),
            Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
            # NOTE(review): these patterns are not anchored, so they only
            # require a run of allowed characters somewhere in the value -
            # confirm whether full-string matching was intended.
            'keyVault': {
                'vaultName': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
                                 error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
                'name': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
                            error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)'),
            },
            'azureStorage': {
                'accountName': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),
                                   error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
                'azureShare': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,63}'),
                                  error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)'),
            },
            Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999),
        },
    ),
}

326
# 'trial' section for the 'frameworkcontroller' platform: a code directory
# plus a list of task roles.
frameworkcontroller_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        'taskRoles': [
            {
                'name': setType('name', str),
                'taskNum': setType('taskNum', int),
                'frameworkAttemptCompletionPolicy': {
                    'minFailedTaskCount': setType('minFailedTaskCount', int),
                    'minSucceededTaskCount': setType('minSucceededTaskCount', int),
                },
                'command': setType('command', str),
                'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
                'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
                'memoryMB': setType('memoryMB', int),
                'image': setType('image', str),
                Optional('privateRegistryAuthPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'privateRegistryAuthPath'),
            },
        ],
    },
}

# 'frameworkcontrollerConfig' section: either an NFS backed or an Azure
# storage backed configuration.
frameworkcontroller_config_schema = {
    'frameworkcontrollerConfig': Or(
        {
            # Alternative 1: NFS storage.
            Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
            Optional('serviceAccountName'): setType('serviceAccountName', str),
            'nfs': {
                'server': setType('server', str),
                'path': setType('path', str),
            },
        },
        {
            # Alternative 2: Azure storage with credentials kept in a key vault.
            Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
            Optional('serviceAccountName'): setType('serviceAccountName', str),
            'keyVault': {
                'vaultName': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
                                 error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
                'name': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
                            error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)'),
            },
            'azureStorage': {
                'accountName': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),
                                   error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
                'azureShare': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,63}'),
                                  error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)'),
            },
            Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999),
        },
    ),
}

373
# 'machineList' section for the 'remote' platform. Every machine is described
# either with an SSH key or with a password.
machine_list_schema = {
    'machineList': [
        Or(
            {
                # Alternative 1: SSH key authentication.
                'ip': setType('ip', str),
                Optional('port'): setNumberRange('port', int, 1, 65535),
                'username': setType('username', str),
                'sshKeyPath': setPathCheck('sshKeyPath'),
                Optional('passphrase'): setType('passphrase', str),
                Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
                Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
                Optional('useActiveGpu'): setType('useActiveGpu', bool),
            },
            {
                # Alternative 2: password authentication.
                'ip': setType('ip', str),
                Optional('port'): setNumberRange('port', int, 1, 65535),
                'username': setType('username', str),
                'passwd': setType('passwd', str),
                Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
                Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
                Optional('useActiveGpu'): setType('useActiveGpu', bool),
            },
        ),
    ],
}
395

chicm-ms's avatar
chicm-ms committed
396
397
398
399
400
401
402
# Complete schema per training service platform: the common keys merged with
# the platform specific 'trial' section and, where present, the platform
# specific configuration section.
training_service_schema_dict = {
    'local': Schema({
        **common_schema, **common_trial_schema}),
    'remote': Schema({
        **common_schema, **common_trial_schema, **machine_list_schema}),
    'pai': Schema({
        **common_schema, **pai_trial_schema, **pai_config_schema}),
    'paiYarn': Schema({
        **common_schema, **pai_yarn_trial_schema, **pai_yarn_config_schema}),
    'kubeflow': Schema({
        **common_schema, **kubeflow_trial_schema, **kubeflow_config_schema}),
    'frameworkcontroller': Schema({
        **common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}),
    'aml': Schema({
        **common_schema, **aml_trial_schema, **aml_config_schema}),
    'dlts': Schema({
        **common_schema, **dlts_trial_schema, **dlts_config_schema}),
}

407

chicm-ms's avatar
chicm-ms committed
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
class NNIConfigSchema:
    '''Validator for a complete experiment configuration.

    ``validate`` applies the schema registered for the configured training
    service platform and then runs the cross-field checks implemented by the
    ``validate_*`` methods below.
    '''

    def validate(self, data):
        '''Validate the experiment configuration ``data``.

        Raises SchemaError when the configuration is invalid. ``data`` must
        contain the 'trainingServicePlatform' key.
        '''
        train_service = data['trainingServicePlatform']
        # Check the platform name first so that an unknown platform is reported
        # as a schema error before it is used as a lookup key.
        Schema(common_schema['trainingServicePlatform']).validate(train_service)
        train_service_schema = training_service_schema_dict[train_service]
        train_service_schema.validate(data)
        self.validate_extras(data)

    def validate_extras(self, experiment_config):
        '''Run every cross-field check on ``experiment_config``.'''
        self.validate_tuner_adivosr_assessor(experiment_config)
        self.validate_pai_trial_conifg(experiment_config)
        self.validate_kubeflow_operators(experiment_config)
        self.validate_eth0_device(experiment_config)

    def validate_tuner_adivosr_assessor(self, experiment_config):
        '''Validate the combination of tuner, advisor and assessor sections.

        An advisor excludes both tuner and assessor; without an advisor a tuner
        is mandatory. The annotation / search space check is run for the
        section that is in use.
        '''
        if experiment_config.get('advisor'):
            if experiment_config.get('assessor') or experiment_config.get('tuner'):
                raise SchemaError('advisor could not be set with assessor or tuner simultaneously!')
            self.validate_annotation_content(experiment_config, 'advisor', 'builtinAdvisorName')
        else:
            if not experiment_config.get('tuner'):
                raise SchemaError('Please provide tuner spec!')
            self.validate_annotation_content(experiment_config, 'tuner', 'builtinTunerName')

    def validate_search_space_content(self, experiment_config):
        '''Validate searchspace content,
        if the searchspace file is not json format or its values does not contain _type and _value which must be specified,
        it will not be a valid searchspace file'''
        try:
            # The context manager closes the file even when parsing fails.
            with open(experiment_config.get('searchSpacePath'), 'r') as search_space_file:
                search_space_content = json.load(search_space_file)
            for value in search_space_content.values():
                if not value.get('_type') or not value.get('_value'):
                    raise SchemaError('please use _type and _value to specify searchspace!')
        except SchemaError:
            # Keep the specific message instead of re-labelling it as a json error.
            raise
        except Exception as e:
            raise SchemaError('searchspace file is not a valid json format! ' + str(e)) from e

    def validate_kubeflow_operators(self, experiment_config):
        '''Validate whether the kubeflow operators are valid'''
        kubeflow_config = experiment_config.get('kubeflowConfig')
        if not kubeflow_config:
            return

        # Roles allowed in the trial section depend on the operator.
        operator = kubeflow_config.get('operator')
        if operator == 'tf-operator':
            if experiment_config.get('trial').get('master') is not None:
                raise SchemaError('kubeflow with tf-operator can not set master')
            if experiment_config.get('trial').get('worker') is None:
                raise SchemaError('kubeflow with tf-operator must set worker')
        elif operator == 'pytorch-operator':
            if experiment_config.get('trial').get('ps') is not None:
                raise SchemaError('kubeflow with pytorch-operator can not set ps')
            if experiment_config.get('trial').get('master') is None:
                raise SchemaError('kubeflow with pytorch-operator must set master')

        # The declared storage type must come with its configuration block.
        storage = kubeflow_config.get('storage')
        if storage == 'nfs':
            if kubeflow_config.get('nfs') is None:
                raise SchemaError('please set nfs configuration!')
        elif storage == 'azureStorage':
            if kubeflow_config.get('azureStorage') is None:
                raise SchemaError('please set azureStorage configuration!')
        elif storage is None:
            if kubeflow_config.get('azureStorage'):
                raise SchemaError('please set storage type!')

    def validate_annotation_content(self, experiment_config, spec_key, builtin_name):
        '''
        Valid whether useAnnotation and searchSpacePath is coexist
        spec_key: 'advisor' or 'tuner'
        builtin_name: 'builtinAdvisorName' or 'builtinTunerName'
        '''
        if experiment_config.get('useAnnotation'):
            if experiment_config.get('searchSpacePath'):
                raise SchemaError('If you set useAnnotation=true, please leave searchSpacePath empty')
        else:
            # validate searchSpaceFile
            # NetworkMorphism is exempt from the search space requirement.
            if experiment_config[spec_key].get(builtin_name) == 'NetworkMorphism':
                return
            if experiment_config[spec_key].get(builtin_name):
                if experiment_config.get('searchSpacePath') is None:
                    raise SchemaError('Please set searchSpacePath!')
                self.validate_search_space_content(experiment_config)

    def validate_pai_config_path(self, experiment_config):
        '''validate paiConfigPath field'''
        if experiment_config.get('trainingServicePlatform') == 'pai':
            if experiment_config.get('trial', {}).get('paiConfigPath'):
                # validate commands
                pai_config = get_yml_content(experiment_config['trial']['paiConfigPath'])
                taskRoles_dict = pai_config.get('taskRoles')
                if not taskRoles_dict:
                    raise SchemaError('Please set taskRoles in paiConfigPath config file!')
            else:
                # Without an external pai config file these fields become mandatory.
                pai_trial_fields_required_list = ['image', 'paiStorageConfigName', 'command']
                for trial_field in pai_trial_fields_required_list:
                    if experiment_config['trial'].get(trial_field) is None:
                        raise SchemaError('Please set {0} in trial configuration,\
                                    or set additional pai configuration file path in paiConfigPath!'.format(trial_field))
                # Resource fields may be given in either the trial or the paiConfig section.
                pai_resource_fields_required_list = ['gpuNum', 'cpuNum', 'memoryMB']
                for required_field in pai_resource_fields_required_list:
                    if experiment_config['trial'].get(required_field) is None and \
                            experiment_config['paiConfig'].get(required_field) is None:
                        raise SchemaError('Please set {0} in trial or paiConfig configuration,\
                                    or set additional pai configuration file path in paiConfigPath!'.format(required_field))

    def validate_pai_trial_conifg(self, experiment_config):
        '''validate the trial config in pai platform'''
        if experiment_config.get('trainingServicePlatform') in ['pai', 'paiYarn']:
            if experiment_config.get('trial').get('shmMB') and \
                    experiment_config['trial']['shmMB'] > experiment_config['trial']['memoryMB']:
                raise SchemaError('shmMB should be no more than memoryMB!')
            # backward compatibility
            warning_information = '{0} is not supported in NNI anymore, please remove the field in config file!\
            please refer https://github.com/microsoft/nni/blob/master/docs/en_US/TrainingService/PaiMode.md#run-an-experiment\
            for the practices of how to get data and output model in trial code'
            if experiment_config.get('trial').get('dataDir'):
                print_warning(warning_information.format('dataDir'))
            if experiment_config.get('trial').get('outputDir'):
                print_warning(warning_information.format('outputDir'))
            self.validate_pai_config_path(experiment_config)

    def validate_eth0_device(self, experiment_config):
        '''validate whether the machine has eth0 device'''
        # Only non-local platforms without an explicit nniManagerIp need eth0.
        if experiment_config.get('trainingServicePlatform') not in ['local'] \
                and not experiment_config.get('nniManagerIp') \
                and 'eth0' not in netifaces.interfaces():
            raise SchemaError('This machine does not contain eth0 network device, please set nniManagerIp in config file!')