config_schema.py 24.3 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3
4

import os
chicm-ms's avatar
chicm-ms committed
5
6
7
8
import json
import netifaces
from schema import Schema, And, Optional, Regex, Or, SchemaError
from nni.package_utils import create_validator_instance, get_all_builtin_names, get_builtin_algo_meta
9
from .constants import SCHEMA_TYPE_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_PATH_ERROR
chicm-ms's avatar
chicm-ms committed
10
from .common_utils import get_yml_content, print_warning
11
12


chicm-ms's avatar
chicm-ms committed
13
def setType(key, valueType):
    '''Build a type validator for the config field `key`.

    Wraps `valueType` in `And` and attaches an error message rendered from
    SCHEMA_TYPE_ERROR with the field name and the type's `__name__`.
    '''
    return And(valueType, error=SCHEMA_TYPE_ERROR % (key, valueType.__name__))
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

def setChoice(key, *args):
    '''Build a validator that accepts only the values passed in `args`.

    The error message is rendered from SCHEMA_RANGE_ERROR with the field name
    and the string form of the allowed-values tuple.
    '''
    allowed = args
    message = SCHEMA_RANGE_ERROR % (key, str(allowed))
    return And(lambda n: n in allowed, error=message)

def setNumberRange(key, keyType, start, end):
    '''Build a validator requiring type `keyType` and start <= value <= end.

    Two checks are chained: the type check (SCHEMA_TYPE_ERROR message) runs
    first, then the inclusive range check (SCHEMA_RANGE_ERROR message).
    '''
    type_check = And(keyType, error=SCHEMA_TYPE_ERROR % (key, keyType.__name__))
    range_text = '(%s,%s)' % (start, end)
    range_check = And(lambda n: start <= n <= end, error=SCHEMA_RANGE_ERROR % (key, range_text))
    return And(type_check, range_check)

def setPathCheck(key):
    '''Build a validator that passes the value to `os.path.exists`.

    The error message is rendered from SCHEMA_PATH_ERROR with the field name.
    '''
    message = SCHEMA_PATH_ERROR % key
    return And(os.path.exists, error=message)
31

chicm-ms's avatar
chicm-ms committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
class AlgoSchema:
    """
    This class is the schema of 'tuner', 'assessor' and 'advisor' sections of experiment configuration file.
    For example:
    AlgoSchema('tuner') creates the schema of tuner section.
    """
    def __init__(self, algo_type):
        """
        Parameters:
        -----------
        algo_type: str
            One of ['tuner', 'assessor', 'advisor'].
            'tuner': This AlgoSchema class create the schema of tuner section.
            'assessor': This AlgoSchema class create the schema of assessor section.
            'advisor': This AlgoSchema class create the schema of advisor section.
        """
        assert algo_type in ['tuner', 'assessor', 'advisor']
        self.algo_type = algo_type
        # Keys accepted in every algorithm section, builtin or customized.
        self.algo_schema = {
            Optional('codeDir'): setPathCheck('codeDir'),
            Optional('classFileName'): setType('classFileName', str),
            Optional('className'): setType('className', str),
            Optional('classArgs'): dict,
            Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
            Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
        }
        # Name of the "builtin algorithm" key for each section type.
        self.builtin_keys = {
            'tuner': 'builtinTunerName',
            'assessor': 'builtinAssessorName',
            'advisor': 'builtinAdvisorName'
        }
        self.builtin_name_schema = {}
        for k, n in self.builtin_keys.items():
            self.builtin_name_schema[k] = {Optional(n): setChoice(n, *get_all_builtin_names(k+'s'))}

        # A customized algorithm must provide all three of these keys.
        self.customized_keys = set(['codeDir', 'classFileName', 'className'])

    def validate_class_args(self, class_args, algo_type, builtin_name):
        """
        Check `class_args` of a builtin algorithm.

        Does nothing when either `builtin_name` or `class_args` is empty.

        Raises:
        -------
        SchemaError
            If the algorithm meta declares 'accept_class_args' as False, or if the
            validator created for the algorithm rejects the arguments.
        """
        if not builtin_name or not class_args:
            return
        meta = get_builtin_algo_meta(algo_type+'s', builtin_name)
        if meta and 'accept_class_args' in meta and meta['accept_class_args'] is False:
            raise SchemaError('classArgs is not allowed.')

        validator = create_validator_instance(algo_type+'s', builtin_name)
        if validator:
            try:
                validator.validate_class_args(**class_args)
            except Exception as e:
                # Surface any validator failure as a schema error, keeping the cause attached.
                raise SchemaError(str(e)) from e

    def missing_customized_keys(self, data):
        """Return the set of customized-algorithm keys that are absent from `data`."""
        return self.customized_keys - set(data.keys())

    def validate_extras(self, data, algo_type):
        """
        Cross-field checks that the plain key/value schema cannot express.

        Raises:
        -------
        SchemaError
            If builtin and customized keys are mixed, if neither a builtin name nor the
            full set of customized keys is given, if the customized class file does not
            exist, or if the class arguments are rejected.
        """
        builtin_key = self.builtin_keys[algo_type]
        customized_given = set(data.keys()) & self.customized_keys
        if (builtin_key in data) and customized_given:
            raise SchemaError('{} and {} cannot be specified at the same time.'.format(
                builtin_key, customized_given
            ))

        missing = self.missing_customized_keys(data)
        if missing and builtin_key not in data:
            raise SchemaError('Either customized {} ({}) or builtin {} ({}) must be set.'.format(
                algo_type, self.customized_keys, algo_type, builtin_key))

        if not missing:
            # All customized keys are present: the class file must exist inside codeDir.
            class_file_name = os.path.join(data['codeDir'], data['classFileName'])
            if not os.path.isfile(class_file_name):
                raise SchemaError('classFileName {} not found.'.format(class_file_name))

        builtin_name = data.get(builtin_key)
        class_args = data.get('classArgs')
        self.validate_class_args(class_args, algo_type, builtin_name)

    def validate(self, data):
        """
        Validate one algorithm section and return it.

        The section is returned so that, when this object is nested inside a `Schema`,
        the validated output keeps the section instead of replacing it with None.
        """
        self.algo_schema.update(self.builtin_name_schema[self.algo_type])
        Schema(self.algo_schema).validate(data)
        self.validate_extras(data, self.algo_type)
        return data

111
# Top-level keys accepted for every training service platform.
common_schema = {
    'authorName': setType('authorName', str),
    'experimentName': setType('experimentName', str),
    Optional('description'): setType('description', str),
    'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
    # A positive integer followed by exactly one unit letter. The unit class is
    # written without '|' so that a literal pipe is not accepted as a unit.
    Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[smhd]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
    Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
    'trainingServicePlatform': setChoice(
        'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'),
    Optional('searchSpacePath'): setPathCheck('searchSpacePath'),
    Optional('multiPhase'): setType('multiPhase', bool),
    Optional('multiThread'): setType('multiThread', bool),
    Optional('nniManagerIp'): setType('nniManagerIp', str),
    # logDir must be a directory, not just an existing path.
    Optional('logDir'): And(os.path.isdir, error=SCHEMA_PATH_ERROR % 'logDir'),
    Optional('debug'): setType('debug', bool),
    Optional('versionCheck'): setType('versionCheck', bool),
    Optional('logLevel'): setChoice('logLevel', 'trace', 'debug', 'info', 'warning', 'error', 'fatal'),
    Optional('logCollection'): setChoice('logCollection', 'http', 'none'),
    'useAnnotation': setType('useAnnotation', bool),
    Optional('tuner'): AlgoSchema('tuner'),
    Optional('advisor'): AlgoSchema('advisor'),
    Optional('assessor'): AlgoSchema('assessor'),
    Optional('localConfig'): {
        # Either a single int or a comma-separated string of ints.
        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
        Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
        Optional('useActiveGpu'): setType('useActiveGpu', bool)
    }
}
139
140

# 'trial' section used by the 'local' and 'remote' platforms.
common_trial_schema = {
    'trial': {
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode')
    }
}

149
# 'trial' section used by the 'paiYarn' platform.
pai_yarn_trial_schema = {
    'trial': {
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
        'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
        'memoryMB': setType('memoryMB', int),
        'image': setType('image', str),
        Optional('authFile'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'authFile'),
        Optional('shmMB'): setType('shmMB', int),
        # The dots between the address octets are escaped so they match a
        # literal '.' and not any character.
        Optional('dataDir'): And(Regex(r'hdfs://(([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
                                 error='ERROR: dataDir format error, dataDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
        Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
                                   error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
        Optional('virtualCluster'): setType('virtualCluster', str),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
        Optional('portList'): [{
            "label": setType('label', str),
            "beginAt": setType('beginAt', int),
            "portNumber": setType('portNumber', int)
        }]
    }
}

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# 'paiYarnConfig' section: two accepted shapes, one with 'passWord' and one with 'token'.
pai_yarn_config_schema = {
    'paiYarnConfig': Or(
        {
            'userName': setType('userName', str),
            'passWord': setType('passWord', str),
            'host': setType('host', str),
        },
        {
            'userName': setType('userName', str),
            'token': setType('token', str),
            'host': setType('host', str),
        },
    ),
}


# 'trial' section used by the 'pai' platform. Most resource fields are optional
# at this level; NNIConfigSchema.validate_pai_config_path checks them afterwards.
pai_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
        'containerNFSMountPath': setType('containerNFSMountPath', str),
        Optional('command'): setType('command', str),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
        Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
        Optional('memoryMB'): setType('memoryMB', int),
        Optional('image'): setType('image', str),
        Optional('virtualCluster'): setType('virtualCluster', str),
        Optional('paiStorageConfigName'): setType('paiStorageConfigName', str),
        Optional('paiConfigPath'): setPathCheck('paiConfigPath')
    }
}

202
# 'paiConfig' section: two accepted shapes, one with 'passWord' and one with 'token'.
pai_config_schema = {
    'paiConfig': Or({
        'userName': setType('userName', str),
        'passWord': setType('passWord', str),
        'host': setType('host', str),
        Optional('reuse'): setType('reuse', bool)
    }, {
        'userName': setType('userName', str),
        'token': setType('token', str),
        'host': setType('host', str),
        Optional('reuse'): setType('reuse', bool)
    })
}

George Cheng's avatar
George Cheng committed
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
# 'trial' section used by the 'dlts' platform; every field is required.
dlts_trial_schema = {
    'trial': {
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
        'image': setType('image', str),
    },
}

# 'dltsConfig' section: only 'dashboard' is required.
dlts_config_schema = {
    'dltsConfig': {
        'dashboard': setType('dashboard', str),
        Optional('cluster'): setType('cluster', str),
        Optional('team'): setType('team', str),
        Optional('email'): setType('email', str),
        Optional('password'): setType('password', str),
    },
}

SparkSnail's avatar
SparkSnail committed
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
# 'trial' section used by the 'aml' platform; every field is required.
aml_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        'command': setType('command', str),
        'image': setType('image', str),
        'computeTarget': setType('computeTarget', str),
    },
}

# 'amlConfig' section; every field is a required string.
aml_config_schema = {
    'amlConfig': {
        'subscriptionId': setType('subscriptionId', str),
        'resourceGroup': setType('resourceGroup', str),
        'workspaceName': setType('workspaceName', str),
    },
}

254
# 'trial' section used by the 'kubeflow' platform. The optional 'ps', 'master'
# and 'worker' roles all share the same field layout, so they are generated
# from one template; each role still gets its own dict and validators.
kubeflow_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
        **{
            Optional(role): {
                'replicas': setType('replicas', int),
                'command': setType('command', str),
                'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
                'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
                'memoryMB': setType('memoryMB', int),
                'image': setType('image', str),
                Optional('privateRegistryAuthPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'privateRegistryAuthPath')
            }
            for role in ('ps', 'master', 'worker')
        }
    }
}

# 'kubeflowConfig' section: either NFS storage, or Azure storage with a key vault.
# The Azure name patterns are anchored with ^...$ because `Regex` matches with a
# search; without anchors a single valid character anywhere in the value would
# satisfy the pattern and the length limits would never be enforced.
kubeflow_config_schema = {
    'kubeflowConfig': Or({
        'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
        'apiVersion': setType('apiVersion', str),
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        'nfs': {
            'server': setType('server', str),
            'path': setType('path', str)
        }
    }, {
        'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
        'apiVersion': setType('apiVersion', str),
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        'keyVault': {
            'vaultName': And(Regex(r'^[0-9a-zA-Z-]{1,127}$'),
                             error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
            'name': And(Regex(r'^[0-9a-zA-Z-]{1,127}$'),
                        error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
        },
        'azureStorage': {
            'accountName': And(Regex(r'^[0-9a-zA-Z-]{3,31}$'),
                               error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
            'azureShare': And(Regex(r'^[0-9a-zA-Z-]{3,63}$'),
                              error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
        },
        Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999)
    })
}

317
318
# 'trial' section used by the 'frameworkcontroller' platform: a code directory
# plus a list of task roles.
frameworkcontroller_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        'taskRoles': [{
            'name': setType('name', str),
            'taskNum': setType('taskNum', int),
            'frameworkAttemptCompletionPolicy': {
                'minFailedTaskCount': setType('minFailedTaskCount', int),
                'minSucceededTaskCount': setType('minSucceededTaskCount', int),
            },
            'command': setType('command', str),
            'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
            'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
            'memoryMB': setType('memoryMB', int),
            'image': setType('image', str),
            Optional('privateRegistryAuthPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'privateRegistryAuthPath')
        }]
    }
}

# 'frameworkcontrollerConfig' section: either NFS storage, or Azure storage with
# a key vault. The Azure name patterns are anchored with ^...$ because `Regex`
# matches with a search; without anchors a single valid character anywhere in
# the value would satisfy the pattern and the length limits would never apply.
frameworkcontroller_config_schema = {
    'frameworkcontrollerConfig': Or({
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        Optional('serviceAccountName'): setType('serviceAccountName', str),
        'nfs': {
            'server': setType('server', str),
            'path': setType('path', str)
        }
    }, {
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        Optional('serviceAccountName'): setType('serviceAccountName', str),
        'keyVault': {
            'vaultName': And(Regex(r'^[0-9a-zA-Z-]{1,127}$'),
                             error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
            'name': And(Regex(r'^[0-9a-zA-Z-]{1,127}$'),
                        error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
        },
        'azureStorage': {
            'accountName': And(Regex(r'^[0-9a-zA-Z-]{3,31}$'),
                               error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
            'azureShare': And(Regex(r'^[0-9a-zA-Z-]{3,63}$'),
                              error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
        },
        Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999)
    })
}

364
# 'machineList' section used by the 'remote' platform. Each machine entry
# authenticates either with an SSH key ('sshKeyPath', optional 'passphrase')
# or with a password ('passwd').
machine_list_schema = {
    'machineList': [Or(
        {
            'ip': setType('ip', str),
            Optional('port'): setNumberRange('port', int, 1, 65535),
            'username': setType('username', str),
            'sshKeyPath': setPathCheck('sshKeyPath'),
            Optional('passphrase'): setType('passphrase', str),
            Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
            Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
            Optional('useActiveGpu'): setType('useActiveGpu', bool)
        },
        {
            'ip': setType('ip', str),
            Optional('port'): setNumberRange('port', int, 1, 65535),
            'username': setType('username', str),
            'passwd': setType('passwd', str),
            Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
            Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
            Optional('useActiveGpu'): setType('useActiveGpu', bool)
        })]
}
386

chicm-ms's avatar
chicm-ms committed
387
388
389
390
391
392
393
# Full schema per training service platform: the common keys merged with the
# platform's own 'trial' section and, where it has one, its config section.
training_service_schema_dict = {
    'local': Schema({**common_schema, **common_trial_schema}),
    'remote': Schema({**common_schema, **common_trial_schema, **machine_list_schema}),
    'pai': Schema({**common_schema, **pai_trial_schema, **pai_config_schema}),
    'paiYarn': Schema({**common_schema, **pai_yarn_trial_schema, **pai_yarn_config_schema}),
    'kubeflow': Schema({**common_schema, **kubeflow_trial_schema, **kubeflow_config_schema}),
    'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}),
    'aml': Schema({**common_schema, **aml_trial_schema, **aml_config_schema}),
    'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}),
}

class NNIConfigSchema:
    '''Validator for a complete experiment configuration dict.'''

    def validate(self, data):
        '''Validate `data` against the schema of its platform, then run the cross-field checks.'''
        train_service = data['trainingServicePlatform']
        # Check the platform name first so the dict lookup below cannot miss.
        Schema(common_schema['trainingServicePlatform']).validate(train_service)
        train_service_schema = training_service_schema_dict[train_service]
        train_service_schema.validate(data)
        self.validate_extras(data)

    def validate_extras(self, experiment_config):
        '''Run every check that spans more than one config field.'''
        self.validate_tuner_adivosr_assessor(experiment_config)
        self.validate_pai_trial_conifg(experiment_config)
        self.validate_kubeflow_operators(experiment_config)
        self.validate_eth0_device(experiment_config)

    def validate_tuner_adivosr_assessor(self, experiment_config):
        '''An advisor excludes tuner and assessor; without an advisor a tuner is required.'''
        if experiment_config.get('advisor'):
            if experiment_config.get('assessor') or experiment_config.get('tuner'):
                raise SchemaError('advisor could not be set with assessor or tuner simultaneously!')
            self.validate_annotation_content(experiment_config, 'advisor', 'builtinAdvisorName')
        else:
            if not experiment_config.get('tuner'):
                raise SchemaError('Please provide tuner spec!')
            self.validate_annotation_content(experiment_config, 'tuner', 'builtinTunerName')

    def validate_search_space_content(self, experiment_config):
        '''Validate searchspace content,
        if the searchspace file is not json format or its values does not contain _type and _value which must be specified,
        it will not be a valid searchspace file'''
        search_space_path = experiment_config.get('searchSpacePath')
        # Only loading the file is guarded, so the "_type/_value" error raised
        # below reaches the caller with its own message. The `with` block closes
        # the file even when parsing fails.
        try:
            with open(search_space_path, 'r') as search_space_file:
                search_space_content = json.load(search_space_file)
        except (OSError, TypeError, ValueError) as e:
            raise SchemaError('searchspace file is not a valid json format! ' + str(e)) from e
        if not isinstance(search_space_content, dict):
            raise SchemaError('searchspace file is not a valid json format! top-level value must be an object')
        for value in search_space_content.values():
            if not isinstance(value, dict) or not value.get('_type') or not value.get('_value'):
                raise SchemaError('please use _type and _value to specify searchspace!')

    def validate_kubeflow_operators(self, experiment_config):
        '''Validate whether the kubeflow operators are valid'''
        kubeflow_config = experiment_config.get('kubeflowConfig')
        if not kubeflow_config:
            return

        operator = kubeflow_config.get('operator')
        if operator == 'tf-operator':
            if experiment_config.get('trial').get('master') is not None:
                raise SchemaError('kubeflow with tf-operator can not set master')
            if experiment_config.get('trial').get('worker') is None:
                raise SchemaError('kubeflow with tf-operator must set worker')
        elif operator == 'pytorch-operator':
            if experiment_config.get('trial').get('ps') is not None:
                raise SchemaError('kubeflow with pytorch-operator can not set ps')
            if experiment_config.get('trial').get('master') is None:
                raise SchemaError('kubeflow with pytorch-operator must set master')

        storage = kubeflow_config.get('storage')
        if storage == 'nfs':
            if kubeflow_config.get('nfs') is None:
                raise SchemaError('please set nfs configuration!')
        elif storage == 'azureStorage':
            if kubeflow_config.get('azureStorage') is None:
                raise SchemaError('please set azureStorage configuration!')
        elif storage is None:
            # An azureStorage block without an explicit storage type is rejected.
            if kubeflow_config.get('azureStorage'):
                raise SchemaError('please set storage type!')

    def validate_annotation_content(self, experiment_config, spec_key, builtin_name):
        '''
        Valid whether useAnnotation and searchSpacePath is coexist
        spec_key: 'advisor' or 'tuner'
        builtin_name: 'builtinAdvisorName' or 'builtinTunerName'
        '''
        if experiment_config.get('useAnnotation'):
            if experiment_config.get('searchSpacePath'):
                raise SchemaError('If you set useAnnotation=true, please leave searchSpacePath empty')
        else:
            # validate searchSpaceFile
            if experiment_config[spec_key].get(builtin_name) == 'NetworkMorphism':
                return
            if experiment_config[spec_key].get(builtin_name):
                if experiment_config.get('searchSpacePath') is None:
                    raise SchemaError('Please set searchSpacePath!')
                self.validate_search_space_content(experiment_config)

    def validate_pai_config_path(self, experiment_config):
        '''validate paiConfigPath field'''
        if experiment_config.get('trainingServicePlatform') == 'pai':
            if experiment_config.get('trial', {}).get('paiConfigPath'):
                # validate commands
                pai_config = get_yml_content(experiment_config['trial']['paiConfigPath'])
                taskRoles_dict = pai_config.get('taskRoles')
                if not taskRoles_dict:
                    raise SchemaError('Please set taskRoles in paiConfigPath config file!')
            else:
                # Without an external pai config file these fields must be set inline.
                pai_trial_fields_required_list = ['image', 'gpuNum', 'cpuNum', 'memoryMB', 'paiStorageConfigName', 'command']
                # Built from adjacent literals so that no source indentation
                # ends up inside the message.
                missing_field_message = (
                    'Please set {0} in trial configuration, '
                    'or set additional pai configuration file path in paiConfigPath!'
                )
                for trial_field in pai_trial_fields_required_list:
                    if experiment_config['trial'].get(trial_field) is None:
                        raise SchemaError(missing_field_message.format(trial_field))

    def validate_pai_trial_conifg(self, experiment_config):
        '''validate the trial config in pai platform'''
        if experiment_config.get('trainingServicePlatform') in ['pai', 'paiYarn']:
            if experiment_config.get('trial').get('shmMB') and \
            experiment_config['trial']['shmMB'] > experiment_config['trial']['memoryMB']:
                raise SchemaError('shmMB should be no more than memoryMB!')
            #backward compatibility
            # Built from adjacent literals so that no source indentation ends up
            # inside the warning text.
            warning_information = (
                '{0} is not supported in NNI anymore, please remove the field in config file! '
                'please refer https://github.com/microsoft/nni/blob/master/docs/en_US/TrainingService/PaiMode.md#run-an-experiment '
                'for the practices of how to get data and output model in trial code'
            )
            if experiment_config.get('trial').get('dataDir'):
                print_warning(warning_information.format('dataDir'))
            if experiment_config.get('trial').get('outputDir'):
                print_warning(warning_information.format('outputDir'))
            self.validate_pai_config_path(experiment_config)

    def validate_eth0_device(self, experiment_config):
        '''validate whether the machine has eth0 device'''
        if experiment_config.get('trainingServicePlatform') not in ['local'] \
        and not experiment_config.get('nniManagerIp') \
        and 'eth0' not in netifaces.interfaces():
            raise SchemaError('This machine does not contain eth0 network device, please set nniManagerIp in config file!')