config_schema.py 23.8 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3
4

import os
chicm-ms's avatar
chicm-ms committed
5
6
7
8
import json
import netifaces
from schema import Schema, And, Optional, Regex, Or, SchemaError
from nni.package_utils import create_validator_instance, get_all_builtin_names, get_builtin_algo_meta
9
from .constants import SCHEMA_TYPE_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_PATH_ERROR
chicm-ms's avatar
chicm-ms committed
10
from .common_utils import get_yml_content, print_warning
11
12


chicm-ms's avatar
chicm-ms committed
13
def setType(key, valueType):
    '''Return a validator requiring the value of ``key`` to be a ``valueType`` instance.

    On failure the error is SCHEMA_TYPE_ERROR filled with the key and the type's name.
    '''
    type_name = valueType.__name__
    message = SCHEMA_TYPE_ERROR % (key, type_name)
    return And(valueType, error=message)
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

def setChoice(key, *args):
    '''Return a validator accepting only values that are members of ``args``.

    On failure the error is SCHEMA_RANGE_ERROR filled with the key and the allowed tuple.
    '''
    message = SCHEMA_RANGE_ERROR % (key, str(args))
    return And(lambda value: value in args, error=message)

def setNumberRange(key, keyType, start, end):
    '''Return a validator requiring a ``keyType`` value within the inclusive range [start, end].

    The type is checked first, so a wrongly-typed value reports SCHEMA_TYPE_ERROR
    rather than a range error.
    '''
    type_check = And(keyType, error=SCHEMA_TYPE_ERROR % (key, keyType.__name__))
    bounds_text = '(%s,%s)' % (start, end)
    range_check = And(lambda n: start <= n <= end, error=SCHEMA_RANGE_ERROR % (key, bounds_text))
    return And(type_check, range_check)

def setPathCheck(key):
    '''Return a validator requiring the value of ``key`` to be an existing path (os.path.exists).'''
    message = SCHEMA_PATH_ERROR % key
    return And(os.path.exists, error=message)
31

chicm-ms's avatar
chicm-ms committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
class AlgoSchema:
    """
    This class is the schema of 'tuner', 'assessor' and 'advisor' sections of experiment configuration file.
    For example:
    AlgoSchema('tuner') creates the schema of tuner section.

    Instances are used as validator objects inside a ``schema.Schema`` dict:
    ``validate(data)`` raises ``SchemaError`` for an invalid section and returns
    ``data`` when it is valid.
    """
    def __init__(self, algo_type):
        """
        Parameters:
        -----------
        algo_type: str
            One of ['tuner', 'assessor', 'advisor'].
            'tuner': This AlgoSchema class create the schema of tuner section.
            'assessor': This AlgoSchema class create the schema of assessor section.
            'advisor': This AlgoSchema class create the schema of advisor section.
        """
        assert algo_type in ['tuner', 'assessor', 'advisor']
        self.algo_type = algo_type
        # Keys accepted in every algorithm section, builtin or customized.
        self.algo_schema = {
            Optional('codeDir'): setPathCheck('codeDir'),
            Optional('classFileName'): setType('classFileName', str),
            Optional('className'): setType('className', str),
            Optional('classArgs'): dict,
            Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
            Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
        }
        self.builtin_keys = {
            'tuner': 'builtinTunerName',
            'assessor': 'builtinAssessorName',
            'advisor': 'builtinAdvisorName'
        }
        # Per algorithm type, the schema for its builtin-name key, restricted to
        # the names reported by get_all_builtin_names.
        self.builtin_name_schema = {}
        for k, n in self.builtin_keys.items():
            self.builtin_name_schema[k] = {Optional(n): setChoice(n, *get_all_builtin_names(k+'s'))}

        # Keys that together describe a customized (user-provided) algorithm.
        self.customized_keys = set(['codeDir', 'classFileName', 'className'])

    def validate_class_args(self, class_args, algo_type, builtin_name):
        """Validate ``classArgs`` of a builtin algorithm.

        Does nothing unless both a builtin name and non-empty class args are given.
        Raises SchemaError if the algorithm's meta forbids classArgs, or if the
        algorithm's validator rejects them (the original exception is chained).
        """
        if not builtin_name or not class_args:
            return
        meta = get_builtin_algo_meta(algo_type+'s', builtin_name)
        if meta and 'accept_class_args' in meta and meta['accept_class_args'] is False:
            raise SchemaError('classArgs is not allowed.')

        validator = create_validator_instance(algo_type+'s', builtin_name)
        if validator:
            try:
                validator.validate_class_args(**class_args)
            except Exception as e:
                # Validators may raise arbitrary exceptions; report them as schema
                # errors but keep the original as the cause for debugging.
                raise SchemaError(str(e)) from e

    def missing_customized_keys(self, data):
        """Return the set of customized-algorithm keys absent from ``data``."""
        return self.customized_keys - set(data.keys())

    def validate_extras(self, data, algo_type):
        """Cross-field checks that the declarative schema cannot express.

        Raises SchemaError when a builtin name is combined with customized keys,
        when neither a complete customized spec nor a builtin name is present,
        when the customized class file does not exist, or when classArgs are invalid.
        """
        builtin_key = self.builtin_keys[algo_type]
        present_customized = set(data.keys()) & self.customized_keys
        if builtin_key in data and present_customized:
            raise SchemaError('{} and {} cannot be specified at the same time.'.format(
                builtin_key, present_customized
            ))

        missing_customized = self.missing_customized_keys(data)
        if missing_customized and builtin_key not in data:
            raise SchemaError('Either customized {} ({}) or builtin {} ({}) must be set.'.format(
                algo_type, self.customized_keys, algo_type, builtin_key))

        if not missing_customized:
            class_file_name = os.path.join(data['codeDir'], data['classFileName'])
            if not os.path.isfile(class_file_name):
                raise SchemaError('classFileName {} not found.'.format(class_file_name))

        builtin_name = data.get(builtin_key)
        class_args = data.get('classArgs')
        self.validate_class_args(class_args, algo_type, builtin_name)

    def validate(self, data):
        """Validate one algorithm section and return it unchanged.

        The schema library substitutes a validator object's return value for the
        validated value, so ``data`` is returned instead of None.
        """
        # Build the combined schema per call instead of mutating self.algo_schema.
        section_schema = dict(self.algo_schema)
        section_schema.update(self.builtin_name_schema[self.algo_type])
        Schema(section_schema).validate(data)
        self.validate_extras(data, self.algo_type)
        return data

111
common_schema = {
    'authorName': setType('authorName', str),
    'experimentName': setType('experimentName', str),
    Optional('description'): setType('description', str),
    'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
    # A positive integer followed by a single unit letter. The unit must be a plain
    # character class: the former '[s|m|h|d]' also accepted a literal '|'.
    Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[smhd]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
    Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
    'trainingServicePlatform': setChoice(
        'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts'),
    Optional('searchSpacePath'): setPathCheck('searchSpacePath'),
    Optional('multiPhase'): setType('multiPhase', bool),
    Optional('multiThread'): setType('multiThread', bool),
    Optional('nniManagerIp'): setType('nniManagerIp', str),
    Optional('logDir'): And(os.path.isdir, error=SCHEMA_PATH_ERROR % 'logDir'),
    Optional('debug'): setType('debug', bool),
    Optional('versionCheck'): setType('versionCheck', bool),
    Optional('logLevel'): setChoice('logLevel', 'trace', 'debug', 'info', 'warning', 'error', 'fatal'),
    Optional('logCollection'): setChoice('logCollection', 'http', 'none'),
    'useAnnotation': setType('useAnnotation', bool),
    # Algorithm sections are validated by AlgoSchema objects.
    Optional('tuner'): AlgoSchema('tuner'),
    Optional('advisor'): AlgoSchema('advisor'),
    Optional('assessor'): AlgoSchema('assessor'),
    Optional('localConfig'): {
        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
        Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
        Optional('useActiveGpu'): setType('useActiveGpu', bool)
    }
}
139
140

common_trial_schema = dict(
    # 'trial' section shared by the 'local' and 'remote' training services.
    trial={
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
    },
)

149
# hdfs://<IPv4>[:port][/path]. Dots are escaped so they only match a literal '.',
# and the pattern is anchored because schema's Regex matches with re.search
# (otherwise any string merely containing such a URI would be accepted).
_HDFS_URI_PATTERN = r'^hdfs://(([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?$'

pai_yarn_trial_schema = {
    'trial':{
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
        'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
        'memoryMB': setType('memoryMB', int),
        'image': setType('image', str),
        Optional('authFile'): setPathCheck('authFile'),
        Optional('shmMB'): setType('shmMB', int),
        # dataDir/outputDir are deprecated (a warning is printed elsewhere) but still format-checked.
        Optional('dataDir'): And(Regex(_HDFS_URI_PATTERN),
                                 error='ERROR: dataDir format error, dataDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
        Optional('outputDir'): And(Regex(_HDFS_URI_PATTERN),
                                   error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
        Optional('virtualCluster'): setType('virtualCluster', str),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
        Optional('portList'): [{
            "label": setType('label', str),
            "beginAt": setType('beginAt', int),
            "portNumber": setType('portNumber', int)
        }]
    }
}

173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
pai_yarn_config_schema = dict(
    # paiYarnConfig authenticates either with a password or with a token.
    paiYarnConfig=Or(
        {
            'userName': setType('userName', str),
            'passWord': setType('passWord', str),
            'host': setType('host', str),
        },
        {
            'userName': setType('userName', str),
            'token': setType('token', str),
            'host': setType('host', str),
        },
    ),
)


pai_trial_schema = dict(
    # Most resource fields are optional here; validate_pai_config_path requires
    # them only when no paiConfigPath file is given.
    trial={
        'codeDir': setPathCheck('codeDir'),
        'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
        'containerNFSMountPath': setType('containerNFSMountPath', str),
        Optional('command'): setType('command', str),
        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
        Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
        Optional('memoryMB'): setType('memoryMB', int),
        Optional('image'): setType('image', str),
        Optional('virtualCluster'): setType('virtualCluster', str),
        Optional('paiStorageConfigName'): setType('paiStorageConfigName', str),
        Optional('paiConfigPath'): setPathCheck('paiConfigPath'),
    },
)

202
pai_config_schema = dict(
    # paiConfig authenticates either with a password or with a token;
    # both variants accept an optional 'reuse' flag.
    paiConfig=Or(
        {
            'userName': setType('userName', str),
            'passWord': setType('passWord', str),
            'host': setType('host', str),
            Optional('reuse'): setType('reuse', bool),
        },
        {
            'userName': setType('userName', str),
            'token': setType('token', str),
            'host': setType('host', str),
            Optional('reuse'): setType('reuse', bool),
        },
    ),
)

George Cheng's avatar
George Cheng committed
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
dlts_trial_schema = dict(
    # All trial fields are mandatory for the DLTS training service.
    trial={
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
        'image': setType('image', str),
    },
)

dlts_config_schema = dict(
    dltsConfig={
        # Only the dashboard address is mandatory.
        'dashboard': setType('dashboard', str),
        Optional('cluster'): setType('cluster', str),
        Optional('team'): setType('team', str),
        # Optional credentials.
        Optional('email'): setType('email', str),
        Optional('password'): setType('password', str),
    },
)

237
def _kubeflow_replica_schema():
    '''Return a fresh schema for one kubeflow replica section ('ps', 'master' or 'worker').'''
    return {
        'replicas': setType('replicas', int),
        'command': setType('command', str),
        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
        'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
        'memoryMB': setType('memoryMB', int),
        'image': setType('image', str),
        Optional('privateRegistryAuthPath'): setPathCheck('privateRegistryAuthPath')
    }


# Which replica sections are required depends on the operator; that rule is
# enforced in NNIConfigSchema.validate_kubeflow_operators.
kubeflow_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
        Optional('ps'): _kubeflow_replica_schema(),
        Optional('master'): _kubeflow_replica_schema(),
        Optional('worker'): _kubeflow_replica_schema(),
    }
}

# The Azure name patterns are anchored with ^...$: schema's Regex matches with
# re.search, so an unanchored pattern accepted any string containing a single
# allowed character and ignored the length bounds.
kubeflow_config_schema = {
    'kubeflowConfig':Or({
        # NFS-backed storage.
        'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
        'apiVersion': setType('apiVersion', str),
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        'nfs': {
            'server': setType('server', str),
            'path': setType('path', str)
        }
    }, {
        # Azure-backed storage.
        'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
        'apiVersion': setType('apiVersion', str),
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        'keyVault': {
            'vaultName': And(Regex(r'^([0-9]|[a-z]|[A-Z]|-){1,127}$'),
                             error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
            'name': And(Regex(r'^([0-9]|[a-z]|[A-Z]|-){1,127}$'),
                        error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
        },
        'azureStorage': {
            'accountName': And(Regex(r'^([0-9]|[a-z]|[A-Z]|-){3,31}$'),
                               error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
            'azureShare': And(Regex(r'^([0-9]|[a-z]|[A-Z]|-){3,63}$'),
                              error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
        },
        Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999)
    })
}

300
301
frameworkcontroller_trial_schema = dict(
    trial={
        'codeDir': setPathCheck('codeDir'),
        # One entry per task role of the framework.
        'taskRoles': [{
            'name': setType('name', str),
            'taskNum': setType('taskNum', int),
            'frameworkAttemptCompletionPolicy': {
                'minFailedTaskCount': setType('minFailedTaskCount', int),
                'minSucceededTaskCount': setType('minSucceededTaskCount', int),
            },
            'command': setType('command', str),
            'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
            'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
            'memoryMB': setType('memoryMB', int),
            'image': setType('image', str),
            Optional('privateRegistryAuthPath'): setPathCheck('privateRegistryAuthPath'),
        }],
    },
)

# The Azure name patterns are anchored with ^...$: schema's Regex matches with
# re.search, so an unanchored pattern accepted any string containing a single
# allowed character and ignored the length bounds.
frameworkcontroller_config_schema = {
    'frameworkcontrollerConfig':Or({
        # NFS-backed storage.
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        Optional('serviceAccountName'): setType('serviceAccountName', str),
        'nfs': {
            'server': setType('server', str),
            'path': setType('path', str)
        }
    }, {
        # Azure-backed storage.
        Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
        Optional('serviceAccountName'): setType('serviceAccountName', str),
        'keyVault': {
            'vaultName': And(Regex(r'^([0-9]|[a-z]|[A-Z]|-){1,127}$'),
                             error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
            'name': And(Regex(r'^([0-9]|[a-z]|[A-Z]|-){1,127}$'),
                        error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
        },
        'azureStorage': {
            'accountName': And(Regex(r'^([0-9]|[a-z]|[A-Z]|-){3,31}$'),
                               error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
            'azureShare': And(Regex(r'^([0-9]|[a-z]|[A-Z]|-){3,63}$'),
                              error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
        },
        Optional('uploadRetryCount'): setNumberRange('uploadRetryCount', int, 1, 99999)
    })
}

347
def _remote_machine_schema(auth_fields):
    '''Return a fresh schema for one machineList entry; ``auth_fields`` holds its authentication keys.'''
    return {
        'ip': setType('ip', str),
        Optional('port'): setNumberRange('port', int, 1, 65535),
        'username': setType('username', str),
        **auth_fields,
        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
        Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
        Optional('useActiveGpu'): setType('useActiveGpu', bool)
    }


# Each remote machine authenticates either with an SSH key (plus optional
# passphrase) or with a password.
machine_list_schema = {
    'machineList': [Or(
        _remote_machine_schema({
            'sshKeyPath': setPathCheck('sshKeyPath'),
            Optional('passphrase'): setType('passphrase', str),
        }),
        _remote_machine_schema({
            'passwd': setType('passwd', str),
        }),
    )]
}
369

chicm-ms's avatar
chicm-ms committed
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
def _merge_schemas(*parts):
    '''Merge schema dicts left to right (later keys win) and wrap the result in a Schema.'''
    merged = {}
    for part in parts:
        merged.update(part)
    return Schema(merged)


# Full config schema for each trainingServicePlatform value.
training_service_schema_dict = {
    'local': _merge_schemas(common_schema, common_trial_schema),
    'remote': _merge_schemas(common_schema, common_trial_schema, machine_list_schema),
    'pai': _merge_schemas(common_schema, pai_trial_schema, pai_config_schema),
    'paiYarn': _merge_schemas(common_schema, pai_yarn_trial_schema, pai_yarn_config_schema),
    'kubeflow': _merge_schemas(common_schema, kubeflow_trial_schema, kubeflow_config_schema),
    'frameworkcontroller': _merge_schemas(common_schema, frameworkcontroller_trial_schema, frameworkcontroller_config_schema),
    'dlts': _merge_schemas(common_schema, dlts_trial_schema, dlts_config_schema),
}

class NNIConfigSchema:
    """Validator for a whole NNI experiment configuration.

    ``validate`` checks the config against the schema registered for its
    ``trainingServicePlatform`` in ``training_service_schema_dict``. It then runs
    cross-field checks that the declarative schema cannot express. Failures are
    reported as ``SchemaError``.
    """
    def validate(self, data):
        '''Validate the experiment config dict ``data``; raise SchemaError if it is invalid.'''
        # Report structural problems as SchemaError instead of letting a
        # TypeError/KeyError escape from the lookups below.
        if not isinstance(data, dict):
            raise SchemaError('experiment config should be a dict, got {}'.format(type(data).__name__))
        if 'trainingServicePlatform' not in data:
            raise SchemaError('Please set trainingServicePlatform in config file!')
        train_service = data['trainingServicePlatform']
        Schema(common_schema['trainingServicePlatform']).validate(train_service)
        train_service_schema = training_service_schema_dict[train_service]
        train_service_schema.validate(data)
        self.validate_extras(data)

    def validate_extras(self, experiment_config):
        '''Run every cross-field check on an already schema-validated config.'''
        self.validate_tuner_adivosr_assessor(experiment_config)
        self.validate_pai_trial_conifg(experiment_config)
        self.validate_kubeflow_operators(experiment_config)
        self.validate_eth0_device(experiment_config)

    def validate_tuner_adivosr_assessor(self, experiment_config):
        '''Require either an advisor alone, or a tuner (optionally with an assessor).'''
        if experiment_config.get('advisor'):
            if experiment_config.get('assessor') or experiment_config.get('tuner'):
                raise SchemaError('advisor could not be set with assessor or tuner simultaneously!')
            self.validate_annotation_content(experiment_config, 'advisor', 'builtinAdvisorName')
        else:
            if not experiment_config.get('tuner'):
                raise SchemaError('Please provide tuner spec!')
            self.validate_annotation_content(experiment_config, 'tuner', 'builtinTunerName')

    def validate_search_space_content(self, experiment_config):
        '''Validate searchspace content.

        The file at searchSpacePath must be a JSON object whose values are objects
        with non-empty '_type' and '_value' entries.
        '''
        search_space_path = experiment_config.get('searchSpacePath')
        try:
            # Use a context manager so the file handle is always closed.
            with open(search_space_path, 'r') as search_space_file:
                search_space_content = json.load(search_space_file)
        except (OSError, ValueError) as e:
            raise SchemaError('searchspace file is not a valid json format! ' + str(e)) from e
        # Content errors are reported separately, not as "invalid json".
        if not isinstance(search_space_content, dict):
            raise SchemaError('please use _type and _value to specify searchspace!')
        for value in search_space_content.values():
            if not isinstance(value, dict) or not value.get('_type') or not value.get('_value'):
                raise SchemaError('please use _type and _value to specify searchspace!')

    def validate_kubeflow_operators(self, experiment_config):
        '''Validate whether the kubeflow operators and storage settings are consistent'''
        kubeflow_config = experiment_config.get('kubeflowConfig')
        if not kubeflow_config:
            return
        trial_config = experiment_config.get('trial')

        operator = kubeflow_config.get('operator')
        if operator == 'tf-operator':
            if trial_config.get('master') is not None:
                raise SchemaError('kubeflow with tf-operator can not set master')
            if trial_config.get('worker') is None:
                raise SchemaError('kubeflow with tf-operator must set worker')
        elif operator == 'pytorch-operator':
            if trial_config.get('ps') is not None:
                raise SchemaError('kubeflow with pytorch-operator can not set ps')
            if trial_config.get('master') is None:
                raise SchemaError('kubeflow with pytorch-operator must set master')

        storage = kubeflow_config.get('storage')
        if storage == 'nfs':
            if kubeflow_config.get('nfs') is None:
                raise SchemaError('please set nfs configuration!')
        elif storage == 'azureStorage':
            if kubeflow_config.get('azureStorage') is None:
                raise SchemaError('please set azureStorage configuration!')
        elif storage is None:
            # Azure settings without an explicit storage type are ambiguous.
            if kubeflow_config.get('azureStorage'):
                raise SchemaError('please set storage type!')

    def validate_annotation_content(self, experiment_config, spec_key, builtin_name):
        '''
        Validate that useAnnotation and searchSpacePath are not set together,
        and that a builtin algorithm has a valid search space.
        spec_key: 'advisor' or 'tuner'
        builtin_name: 'builtinAdvisorName' or 'builtinTunerName'
        '''
        if experiment_config.get('useAnnotation'):
            if experiment_config.get('searchSpacePath'):
                raise SchemaError('If you set useAnnotation=true, please leave searchSpacePath empty')
        else:
            # validate searchSpaceFile
            # NetworkMorphism is exempted from the search space requirement.
            if experiment_config[spec_key].get(builtin_name) == 'NetworkMorphism':
                return
            if experiment_config[spec_key].get(builtin_name):
                if experiment_config.get('searchSpacePath') is None:
                    raise SchemaError('Please set searchSpacePath!')
                self.validate_search_space_content(experiment_config)

    def validate_pai_config_path(self, experiment_config):
        '''validate paiConfigPath field.

        With paiConfigPath the referenced YAML must define taskRoles; without it
        the trial section must provide every field the pai service needs.
        '''
        if experiment_config.get('trainingServicePlatform') == 'pai':
            if experiment_config.get('trial', {}).get('paiConfigPath'):
                # validate commands
                pai_config = get_yml_content(experiment_config['trial']['paiConfigPath'])
                # Guard against an empty or non-mapping YAML file.
                taskRoles_dict = pai_config.get('taskRoles') if isinstance(pai_config, dict) else None
                if not taskRoles_dict:
                    raise SchemaError('Please set taskRoles in paiConfigPath config file!')
            else:
                pai_trial_fields_required_list = ['image', 'gpuNum', 'cpuNum', 'memoryMB', 'paiStorageConfigName', 'command']
                for trial_field in pai_trial_fields_required_list:
                    if experiment_config['trial'].get(trial_field) is None:
                        # Adjacent literals keep the message on one line (a backslash
                        # continuation inside the string embedded the indentation).
                        raise SchemaError('Please set {0} in trial configuration, '
                                          'or set additional pai configuration file path in paiConfigPath!'.format(trial_field))

    def validate_pai_trial_conifg(self, experiment_config):
        '''validate the trial config in pai platform'''
        if experiment_config.get('trainingServicePlatform') in ['pai', 'paiYarn']:
            trial_config = experiment_config.get('trial')
            shm_mb = trial_config.get('shmMB')
            memory_mb = trial_config.get('memoryMB')
            if shm_mb and memory_mb is not None and shm_mb > memory_mb:
                raise SchemaError('shmMB should be no more than memoryMB!')
            #backward compatibility
            warning_information = ('{0} is not supported in NNI anymore, please remove the field in config file! '
                                   'please refer https://github.com/microsoft/nni/blob/master/docs/en_US/TrainingService/PaiMode.md#run-an-experiment '
                                   'for the practices of how to get data and output model in trial code')
            if trial_config.get('dataDir'):
                print_warning(warning_information.format('dataDir'))
            if trial_config.get('outputDir'):
                print_warning(warning_information.format('outputDir'))
            self.validate_pai_config_path(experiment_config)

    def validate_eth0_device(self, experiment_config):
        '''validate whether the machine has eth0 device'''
        # Non-local platforms fall back to eth0 when nniManagerIp is not given.
        if experiment_config.get('trainingServicePlatform') not in ['local'] \
        and not experiment_config.get('nniManagerIp') \
        and 'eth0' not in netifaces.interfaces():
            raise SchemaError('This machine does not contain eth0 network device, please set nniManagerIp in config file!')