Unverified Commit d779dc97 authored by SparkSnail, committed by GitHub

Refactor nnictl, show error messages gracefully (#986)

An advisor can no longer be set together with a tuner or an assessor.
Refactor config file error messages.
Support scientific notation (e.g. 1e-2) in config values; the YAML dependency moves from pyyaml to ruamel.yaml.
parent 19026041
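
The scientific-notation point above is tied to the pyyaml → ruamel.yaml swap in the diff below. A minimal sketch of the difference, not part of the PR: it assumes ruamel.yaml's resolver reads a bare exponent form such as `1e-2` as a float (PyYAML's YAML 1.1 resolver reads it as a string); the `learning_rate` key is hypothetical and used only for illustration.

```python
# Sketch (not part of the PR): why the YAML dependency changed.
# PyYAML follows YAML 1.1, where a bare "1e-2" resolves to the string '1e-2';
# ruamel.yaml is assumed to resolve it to the float 0.01, so numeric classArgs
# can be written in scientific notation.
import ruamel.yaml as yaml

config_text = """
tuner:
  classArgs:
    optimize_mode: maximize
    learning_rate: 1e-2     # hypothetical numeric argument, for illustration only
"""

config = yaml.load(config_text, Loader=yaml.Loader)
value = config['tuner']['classArgs']['learning_rate']
print(type(value), value)   # expected: <class 'float'> 0.01
```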
......@@ -51,7 +51,7 @@ setup(
'json_tricks',
'numpy',
'psutil',
'pyyaml',
'ruamel.yaml',
'requests',
'scipy',
'schema',
......
......@@ -24,7 +24,7 @@ import json
import os
import subprocess
import requests
import yaml
import ruamel.yaml as yaml
EXPERIMENT_DONE_SIGNAL = '"Experiment done"'
......@@ -55,7 +55,7 @@ def remove_files(file_list):
def get_yml_content(file_path):
'''Load yaml file content'''
with open(file_path, 'r') as file:
return yaml.load(file)
return yaml.load(file, Loader=yaml.Loader)
def dump_yml_content(file_path, content):
'''Dump yaml file content'''
......
......@@ -19,7 +19,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import json
import yaml
import ruamel.yaml as yaml
import psutil
import socket
from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO, COLOR_RED_FORMAT, COLOR_YELLOW_FORMAT
......@@ -28,7 +28,7 @@ def get_yml_content(file_path):
'''Load yaml file content'''
try:
with open(file_path, 'r') as file:
return yaml.load(file)
return yaml.load(file, Loader=yaml.Loader)
except TypeError as err:
print('Error: ', err)
return None
......
......@@ -20,267 +20,310 @@
import os
from schema import Schema, And, Use, Optional, Regex, Or
from .constants import SCHEMA_TYPE_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_PATH_ERROR
def setType(key, type):
'''check key type'''
return And(type, error=SCHEMA_TYPE_ERROR % (key, type.__name__))
def setChoice(key, *args):
'''check choice'''
return And(lambda n: n in args, error=SCHEMA_RANGE_ERROR % (key, str(args)))
def setNumberRange(key, keyType, start, end):
'''check number range'''
return And(
And(keyType, error=SCHEMA_TYPE_ERROR % (key, keyType.__name__)),
And(lambda n: start <= n <= end, error=SCHEMA_RANGE_ERROR % (key, '(%s,%s)' % (start, end))),
)
def setPathCheck(key):
'''check if path exist'''
return And(os.path.exists, error=SCHEMA_PATH_ERROR % key)
common_schema = {
'authorName': str,
'experimentName': str,
Optional('description'): str,
'trialConcurrency': And(int, lambda n: 1 <=n <= 999999),
Optional('maxExecDuration'): Regex(r'^[1-9][0-9]*[s|m|h|d]$'),
Optional('maxTrialNum'): And(int, lambda x: 1 <= x <= 99999),
'trainingServicePlatform': And(str, lambda x: x in ['remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller']),
Optional('searchSpacePath'): os.path.exists,
Optional('multiPhase'): bool,
Optional('multiThread'): bool,
Optional('nniManagerIp'): str,
Optional('logDir'): os.path.isdir,
Optional('debug'): bool,
Optional('logLevel'): Or('trace', 'debug', 'info', 'warning', 'error', 'fatal'),
Optional('logCollection'): Or('http', 'none'),
'useAnnotation': bool,
Optional('advisor'): Or({
'builtinAdvisorName': Or('Hyperband'),
'classArgs': {
'optimize_mode': Or('maximize', 'minimize'),
Optional('R'): int,
Optional('eta'): int
'authorName': setType('authorName', str),
'experimentName': setType('experimentName', str),
Optional('description'): setType('description', str),
'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$',error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
'trainingServicePlatform': setChoice('trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller'),
Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
Optional('multiPhase'): setType('multiPhase', bool),
Optional('multiThread'): setType('multiThread', bool),
Optional('nniManagerIp'): setType('nniManagerIp', str),
Optional('logDir'): And(os.path.isdir, error=SCHEMA_PATH_ERROR % 'logDir'),
Optional('debug'): setType('debug', bool),
Optional('logLevel'): setChoice('logLevel', 'trace', 'debug', 'info', 'warning', 'error', 'fatal'),
Optional('logCollection'): setChoice('logCollection', 'http', 'none'),
'useAnnotation': setType('useAnnotation', bool),
Optional('tuner'): dict,
Optional('advisor'): dict,
Optional('assessor'): dict,
Optional('localConfig'): {
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!')
}
}
tuner_schema_dict = {
('TPE', 'Anneal', 'SMAC', 'Evolution'): {
'builtinTunerName': setChoice('builtinTunerName', 'TPE', 'Anneal', 'SMAC', 'Evolution'),
Optional('classArgs'): {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'codeDir': os.path.exists,
'classFileName': str,
'className': str,
Optional('classArgs'): dict,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'builtinAdvisorName': Or('BOHB'),
'classArgs': {
'optimize_mode': Or('maximize', 'minimize'),
Optional('min_budget'): And(int, lambda x: 0 <= x <= 9999),
Optional('max_budget'): And(int, lambda x: 0 <= x <= 9999),
Optional('eta'): And(int, lambda x: 0 <= x <= 9999),
Optional('min_points_in_model'): And(int, lambda x: 0 <= x <= 9999),
Optional('top_n_percent'): And(int, lambda x: 1 <= x <= 99),
Optional('num_samples'): And(int, lambda x: 1 <= x <= 9999),
Optional('random_fraction'): And(float, lambda x: 0.0 <= x <= 9999.0),
Optional('bandwidth_factor'): And(float, lambda x: 0.0 <= x <= 9999.0),
Optional('min_bandwidth'): And(float, lambda x: 0.0 <= x <= 9999.0)
('BatchTuner', 'GridSearch', 'Random'): {
'builtinTunerName': setChoice('builtinTunerName', 'BatchTuner', 'GridSearch', 'Random'),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'codeDir': os.path.exists,
'classFileName': str,
'className': str,
Optional('classArgs'): dict,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
}),
Optional('tuner'): Or({
'builtinTunerName': Or('TPE', 'Anneal', 'SMAC', 'Evolution'),
Optional('classArgs'): {
'optimize_mode': Or('maximize', 'minimize')
'NetworkMorphism': {
'builtinTunerName': 'NetworkMorphism',
'classArgs': {
Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('task'): setChoice('task', 'cv','nlp','common'),
Optional('input_width'): setType('input_width', int),
Optional('input_channel'): setType('input_channel', int),
Optional('n_output_node'): setType('n_output_node', int),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
Optional('includeIntermediateResults'): bool,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'builtinTunerName': Or('BatchTuner', 'GridSearch', 'Random'),
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'builtinTunerName': 'NetworkMorphism',
'classArgs': {
Optional('optimize_mode'): Or('maximize', 'minimize'),
Optional('task'): And(str, lambda x: x in ['cv','nlp','common']),
Optional('input_width'): int,
Optional('input_channel'): int,
Optional('n_output_node'): int,
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'builtinTunerName': 'MetisTuner',
'classArgs': {
Optional('optimize_mode'): Or('maximize', 'minimize'),
Optional('no_resampling'): bool,
Optional('no_candidates'): bool,
Optional('selection_num_starting_points'): int,
Optional('cold_start_num'): int,
'MetisTuner': {
'builtinTunerName': 'MetisTuner',
'classArgs': {
Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('no_resampling'): setType('no_resampling', bool),
Optional('no_candidates'): setType('no_candidates', bool),
Optional('selection_num_starting_points'): setType('selection_num_starting_points', int),
Optional('cold_start_num'): setType('cold_start_num', int),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
'customized': {
'codeDir': setPathCheck('codeDir'),
'classFileName': setType('classFileName', str),
'className': setType('className', str),
Optional('classArgs'): dict,
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
}
}
advisor_schema_dict = {
'Hyperband':{
'builtinAdvisorName': Or('Hyperband'),
'classArgs': {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('R'): setType('R', int),
Optional('eta'): setType('eta', int)
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'codeDir': os.path.exists,
'classFileName': str,
'className': str,
Optional('classArgs'): dict,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
}),
Optional('assessor'): Or({
'builtinAssessorName': lambda x: x in ['Medianstop'],
Optional('classArgs'): {
Optional('optimize_mode'): Or('maximize', 'minimize'),
Optional('start_step'): And(int, lambda x: 0 <= x <= 9999)
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999)
},{
'builtinAssessorName': lambda x: x in ['Curvefitting'],
Optional('classArgs'): {
'epoch_num': And(int, lambda x: 0 <= x <= 9999),
Optional('optimize_mode'): Or('maximize', 'minimize'),
Optional('start_step'): And(int, lambda x: 0 <= x <= 9999),
Optional('threshold'): And(float, lambda x: 0.0 <= x <= 9999.0),
Optional('gap'): And(int, lambda x: 1 <= x <= 9999)
'BOHB':{
'builtinAdvisorName': Or('BOHB'),
'classArgs': {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('min_budget'): setNumberRange('min_budget', int, 0, 9999),
Optional('max_budget'): setNumberRange('max_budget', int, 0, 9999),
Optional('eta'):setNumberRange('eta', int, 0, 9999),
Optional('min_points_in_model'): setNumberRange('min_points_in_model', int, 0, 9999),
Optional('top_n_percent'): setNumberRange('top_n_percent', int, 1, 99),
Optional('num_samples'): setNumberRange('num_samples', int, 1, 9999),
Optional('random_fraction'): setNumberRange('random_fraction', float, 0, 9999),
Optional('bandwidth_factor'): setNumberRange('bandwidth_factor', float, 0, 9999),
Optional('min_bandwidth'): setNumberRange('min_bandwidth', float, 0, 9999),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999)
},{
'codeDir': os.path.exists,
'classFileName': str,
'className': str,
Optional('classArgs'): dict,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
}),
Optional('localConfig'): {
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0))
'customized':{
'codeDir': setPathCheck('codeDir'),
'classFileName': setType('classFileName', str),
'className': setType('className', str),
Optional('classArgs'): dict,
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
}
}
assessor_schema_dict = {
'Medianstop': {
'builtinAssessorName': 'Medianstop',
Optional('classArgs'): {
Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('start_step'): setNumberRange('start_step', int, 0, 9999),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
'Curvefitting': {
'builtinAssessorName': 'Curvefitting',
Optional('classArgs'): {
'epoch_num': setNumberRange('epoch_num', int, 0, 9999),
Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('start_step'): setNumberRange('start_step', int, 0, 9999),
Optional('threshold'): setNumberRange('threshold', float, 0, 9999),
Optional('gap'): setNumberRange('gap', int, 1, 9999),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
'customized': {
'codeDir': setPathCheck('codeDir'),
'classFileName': setType('classFileName', str),
'className': setType('className', str),
Optional('classArgs'): dict,
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999)
}
}
common_trial_schema = {
'trial':{
'command': str,
'codeDir': os.path.exists,
'gpuNum': And(int, lambda x: 0 <= x <= 99999)
'command': setType('command', str),
'codeDir': setPathCheck('codeDir'),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999)
}
}
pai_trial_schema = {
'trial':{
'command': str,
'codeDir': os.path.exists,
'gpuNum': And(int, lambda x: 0 <= x <= 99999),
'cpuNum': And(int, lambda x: 0 <= x <= 99999),
'memoryMB': int,
'image': str,
Optional('shmMB'): int,
Optional('dataDir'): Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
Optional('outputDir'): Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
Optional('virtualCluster'): str
'command': setType('command', str),
'codeDir': setPathCheck('codeDir'),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str),
Optional('shmMB'): setType('shmMB', int),
Optional('dataDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\
error='ERROR: dataDir format error, dataDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\
error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
Optional('virtualCluster'): setType('virtualCluster', str),
}
}
pai_config_schema = {
'paiConfig':{
'userName': str,
'passWord': str,
'host': str
}
'paiConfig':{
'userName': setType('userName', str),
'passWord': setType('passWord', str),
'host': setType('host', str)
}
}
kubeflow_trial_schema = {
'trial':{
'codeDir': os.path.exists,
'codeDir': setPathCheck('codeDir'),
Optional('ps'): {
'replicas': int,
'command': str,
'gpuNum': And(int, lambda x: 0 <= x <= 99999),
'cpuNum': And(int, lambda x: 0 <= x <= 99999),
'memoryMB': int,
'image': str
'replicas': setType('replicas', int),
'command': setType('command', str),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str)
},
Optional('master'): {
'replicas': int,
'command': str,
'gpuNum': And(int, lambda x: 0 <= x <= 99999),
'cpuNum': And(int, lambda x: 0 <= x <= 99999),
'memoryMB': int,
'image': str
'replicas': setType('replicas', int),
'command': setType('command', str),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str)
},
Optional('worker'):{
'replicas': int,
'command': str,
'gpuNum': And(int, lambda x: 0 <= x <= 99999),
'cpuNum': And(int, lambda x: 0 <= x <= 99999),
'memoryMB': int,
'image': str
'replicas': setType('replicas', int),
'command': setType('command', str),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str)
}
}
}
kubeflow_config_schema = {
'kubeflowConfig':Or({
'operator': Or('tf-operator', 'pytorch-operator'),
'apiVersion': str,
Optional('storage'): Or('nfs', 'azureStorage'),
'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
'apiVersion': setType('apiVersion', str),
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
'nfs': {
'server': str,
'path': str
'server': setType('server', str),
'path': setType('path', str)
}
},{
'operator': Or('tf-operator', 'pytorch-operator'),
'apiVersion': str,
Optional('storage'): Or('nfs', 'azureStorage'),
'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
'apiVersion': setType('apiVersion', str),
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
'keyVault': {
'vaultName': Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
'name': Regex('([0-9]|[a-z]|[A-Z]|-){1,127}')
'vaultName': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\
error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
'name': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\
error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
},
'azureStorage': {
'accountName': Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),
'azureShare': Regex('([0-9]|[a-z]|[A-Z]|-){3,63}')
'accountName': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),\
error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
'azureShare': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,63}'),\
error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
}
})
}
frameworkcontroller_trial_schema = {
'trial':{
'codeDir': os.path.exists,
'codeDir': setPathCheck('codeDir'),
'taskRoles': [{
'name': str,
'taskNum': int,
'name': setType('name', str),
'taskNum': setType('taskNum', int),
'frameworkAttemptCompletionPolicy': {
'minFailedTaskCount': int,
'minSucceededTaskCount': int
'minFailedTaskCount': setType('minFailedTaskCount', int),
'minSucceededTaskCount': setType('minSucceededTaskCount', int),
},
'command': str,
'gpuNum': And(int, lambda x: 0 <= x <= 99999),
'cpuNum': And(int, lambda x: 0 <= x <= 99999),
'memoryMB': int,
'image': str
'command': setType('command', str),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str)
}]
}
}
frameworkcontroller_config_schema = {
'frameworkcontrollerConfig':Or({
Optional('storage'): Or('nfs', 'azureStorage'),
Optional('serviceAccountName'): str,
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
Optional('serviceAccountName'): setType('serviceAccountName', str),
'nfs': {
'server': str,
'path': str
'server': setType('server', str),
'path': setType('path', str)
}
},{
Optional('storage'): Or('nfs', 'azureStorage'),
Optional('serviceAccountName'): str,
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
Optional('serviceAccountName'): setType('serviceAccountName', str),
'keyVault': {
'vaultName': Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
'name': Regex('([0-9]|[a-z]|[A-Z]|-){1,127}')
'vaultName': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\
error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
'name': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\
error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
},
'azureStorage': {
'accountName': Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),
'azureShare': Regex('([0-9]|[a-z]|[A-Z]|-){3,63}')
'accountName': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),\
error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
'azureShare': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,63}'),\
error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
}
})
}
machine_list_schima = {
Optional('machineList'):[Or({
'ip': str,
Optional('port'): And(int, lambda x: 0 < x < 65535),
'username': str,
'passwd': str,
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0))
'ip': setType('ip', str),
Optional('port'): setNumberRange('port', int, 1, 65535),
'username': setType('username', str),
'passwd': setType('passwd', str),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!')
},{
'ip': str,
Optional('port'): And(int, lambda x: 0 < x < 65535),
'username': str,
'sshKeyPath': os.path.exists,
Optional('passphrase'): str,
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0))
'ip': setType('ip', str),
Optional('port'): setNumberRange('port', int, 1, 65535),
'username': setType('username', str),
'sshKeyPath': setPathCheck('sshKeyPath'),
Optional('passphrase'): setType('passphrase', str),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!')
})]
}
......
......@@ -85,3 +85,9 @@ COLOR_RED_FORMAT = '\033[1;31;31m%s\033[0m'
COLOR_GREEN_FORMAT = '\033[1;32;32m%s\033[0m'
COLOR_YELLOW_FORMAT = '\033[1;33;33m%s\033[0m'
SCHEMA_TYPE_ERROR = '%s should be %s type!'
SCHEMA_RANGE_ERROR = '%s should be in range of %s!'
SCHEMA_PATH_ERROR = '%s path not exist!'
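
For reference, a condensed sketch (not from the PR) of how the new helpers and error templates are expected to behave together: it copies `setType`/`setNumberRange` and the `SCHEMA_*` templates from the hunks above into a tiny two-key schema and feeds it an invalid `trialConcurrency`.

```python
# Sketch (not part of the PR): how the helper-based schema surfaces readable errors.
# setType/setNumberRange and the SCHEMA_* templates are copied from the hunks above.
from schema import Schema, And, SchemaError

SCHEMA_TYPE_ERROR = '%s should be %s type!'
SCHEMA_RANGE_ERROR = '%s should be in range of %s!'

def setType(key, type):
    return And(type, error=SCHEMA_TYPE_ERROR % (key, type.__name__))

def setNumberRange(key, keyType, start, end):
    return And(
        And(keyType, error=SCHEMA_TYPE_ERROR % (key, keyType.__name__)),
        And(lambda n: start <= n <= end, error=SCHEMA_RANGE_ERROR % (key, '(%s,%s)' % (start, end))),
    )

mini_schema = Schema({
    'authorName': setType('authorName', str),
    'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
})

try:
    mini_schema.validate({'authorName': 'nni', 'trialConcurrency': 0})
except SchemaError as error:
    print(error)   # expected to include: trialConcurrency should be in range of (1,99999)!
```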
......@@ -20,8 +20,11 @@
import os
import json
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA, FRAMEWORKCONTROLLER_CONFIG_SCHEMA
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA, FRAMEWORKCONTROLLER_CONFIG_SCHEMA, \
tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
from schema import SchemaMissingKeyError, SchemaForbiddenKeyError, SchemaUnexpectedTypeError, SchemaWrongKeyError, SchemaError
from .common_utils import get_json_content, print_error, print_warning, print_normal
from schema import Schema, And, Use, Optional, Regex, Or
def expand_path(experiment_config, key):
'''Change '~' to user home directory'''
......@@ -134,21 +137,49 @@ def validate_common_content(experiment_config):
'kubeflow': KUBEFLOW_CONFIG_SCHEMA,
'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA
}
separate_schema_dict = {
'tuner': tuner_schema_dict,
'advisor': advisor_schema_dict,
'assessor': assessor_schema_dict
}
separate_builtInName_dict = {
'tuner': 'builtinTunerName',
'advisor': 'builtinAdvisorName',
'assessor': 'builtinAssessorName'
}
try:
schema_dict.get(experiment_config['trainingServicePlatform']).validate(experiment_config)
#set default value
if experiment_config.get('maxExecDuration') is None:
experiment_config['maxExecDuration'] = '999d'
if experiment_config.get('maxTrialNum') is None:
experiment_config['maxTrialNum'] = 99999
if experiment_config['trainingServicePlatform'] == 'remote':
for index in range(len(experiment_config['machineList'])):
if experiment_config['machineList'][index].get('port') is None:
experiment_config['machineList'][index]['port'] = 22
except Exception as exception:
print_error('Your config file is not correct, please check your config file content!\n%s' % exception)
for separate_key in separate_schema_dict.keys():
if experiment_config.get(separate_key):
if experiment_config[separate_key].get(separate_builtInName_dict[separate_key]):
validate = False
for key in separate_schema_dict[separate_key].keys():
if key.__contains__(experiment_config[separate_key][separate_builtInName_dict[separate_key]]):
Schema({**separate_schema_dict[separate_key][key]}).validate(experiment_config[separate_key])
validate = True
break
if not validate:
print_error('%s %s error!' % (separate_key, separate_builtInName_dict[separate_key]))
exit(1)
else:
Schema({**separate_schema_dict[separate_key]['customized']}).validate(experiment_config[separate_key])
except SchemaError as error:
print_error('Your config file is not correct, please check your config file content!')
if error.__str__().__contains__('Wrong key'):
print_error(' '.join(error.__str__().split()[:3]))
else:
print_error(error)
exit(1)
#set default value
if experiment_config.get('maxExecDuration') is None:
experiment_config['maxExecDuration'] = '999d'
if experiment_config.get('maxTrialNum') is None:
experiment_config['maxTrialNum'] = 99999
if experiment_config['trainingServicePlatform'] == 'remote':
for index in range(len(experiment_config['machineList'])):
if experiment_config['machineList'][index].get('port') is None:
experiment_config['machineList'][index]['port'] = 22
def validate_customized_file(experiment_config, spec_key):
'''
......@@ -230,6 +261,9 @@ def validate_all_content(experiment_config, config_path):
validate_pai_trial_conifg(experiment_config)
experiment_config['maxExecDuration'] = parse_time(experiment_config['maxExecDuration'])
if experiment_config.get('advisor'):
if experiment_config.get('assessor') or experiment_config.get('tuner'):
print_error('advisor could not be set with assessor or tuner simultaneously!')
exit(1)
parse_advisor_content(experiment_config)
validate_annotation_content(experiment_config, 'advisor', 'builtinAdvisorName')
else:
......
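
The dispatch added to `validate_common_content` above selects a sub-schema from `tuner_schema_dict`/`advisor_schema_dict`/`assessor_schema_dict` by builtin name, falling back to the `customized` entry. A condensed sketch of that logic, not from the PR; `validate_separate_section` is a hypothetical helper name and the error handling is simplified (the PR prints the error and calls `exit(1)`).

```python
# Sketch (not part of the PR): builtin-name dispatch for tuner/advisor/assessor sections.
# separate_schema_dict maps 'tuner'/'advisor'/'assessor' to the *_schema_dict tables above;
# separate_builtInName_dict maps them to 'builtinTunerName'/'builtinAdvisorName'/'builtinAssessorName'.
from schema import Schema, SchemaError

def validate_separate_section(experiment_config, separate_key,
                              separate_schema_dict, separate_builtInName_dict):
    section = experiment_config.get(separate_key)
    if not section:
        return
    builtin_name = section.get(separate_builtInName_dict[separate_key])
    if not builtin_name:
        # No builtin name given: validate against the 'customized' schema
        # (codeDir / classFileName / className / ...).
        Schema({**separate_schema_dict[separate_key]['customized']}).validate(section)
        return
    # Keys of the schema tables are either single names ('MetisTuner') or
    # tuples of names (('TPE', 'Anneal', 'SMAC', 'Evolution')).
    for key, sub_schema in separate_schema_dict[separate_key].items():
        matched = builtin_name in key if isinstance(key, tuple) else builtin_name == key
        if matched:
            Schema({**sub_schema}).validate(section)
            return
    # The PR prints an error and exits here; raising keeps the sketch self-contained.
    raise SchemaError('%s %s error!' % (separate_key, separate_builtInName_dict[separate_key]))
```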
......@@ -8,7 +8,7 @@ setuptools.setup(
python_requires = '>=3.5',
install_requires = [
'requests',
'pyyaml',
'ruamel.yaml',
'psutil',
'astor',
'schema',
......