Unverified Commit 0663218b authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #163 from Microsoft/master

merge master
parents 6c9360a5 cf983800
......@@ -18,17 +18,20 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import os
import sys
import json
import yaml
import ruamel.yaml as yaml
import psutil
import socket
from pathlib import Path
from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO, COLOR_RED_FORMAT, COLOR_YELLOW_FORMAT
def get_yml_content(file_path):
'''Load yaml file content'''
try:
with open(file_path, 'r') as file:
return yaml.load(file)
return yaml.load(file, Loader=yaml.Loader)
except TypeError as err:
print('Error: ', err)
return None
......@@ -71,3 +74,15 @@ def detect_port(port):
return True
except:
return False
def get_user():
if sys.platform =='win32':
return os.environ['USERNAME']
else:
return os.environ['USER']
def get_python_dir(sitepackages_path):
if sys.platform == "win32":
return str(Path(sitepackages_path))
else:
return str(Path(sitepackages_path).parents[2])
\ No newline at end of file
......@@ -20,267 +20,310 @@
import os
from schema import Schema, And, Use, Optional, Regex, Or
from .constants import SCHEMA_TYPE_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_PATH_ERROR
def setType(key, type):
'''check key type'''
return And(type, error=SCHEMA_TYPE_ERROR % (key, type.__name__))
def setChoice(key, *args):
'''check choice'''
return And(lambda n: n in args, error=SCHEMA_RANGE_ERROR % (key, str(args)))
def setNumberRange(key, keyType, start, end):
'''check number range'''
return And(
And(keyType, error=SCHEMA_TYPE_ERROR % (key, keyType.__name__)),
And(lambda n: start <= n <= end, error=SCHEMA_RANGE_ERROR % (key, '(%s,%s)' % (start, end))),
)
def setPathCheck(key):
'''check if path exist'''
return And(os.path.exists, error=SCHEMA_PATH_ERROR % key)
common_schema = {
'authorName': str,
'experimentName': str,
Optional('description'): str,
'trialConcurrency': And(int, lambda n: 1 <=n <= 999999),
Optional('maxExecDuration'): Regex(r'^[1-9][0-9]*[s|m|h|d]$'),
Optional('maxTrialNum'): And(int, lambda x: 1 <= x <= 99999),
'trainingServicePlatform': And(str, lambda x: x in ['remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller']),
Optional('searchSpacePath'): os.path.exists,
Optional('multiPhase'): bool,
Optional('multiThread'): bool,
Optional('nniManagerIp'): str,
Optional('logDir'): os.path.isdir,
Optional('debug'): bool,
Optional('logLevel'): Or('trace', 'debug', 'info', 'warning', 'error', 'fatal'),
Optional('logCollection'): Or('http', 'none'),
'useAnnotation': bool,
Optional('advisor'): Or({
'builtinAdvisorName': Or('Hyperband'),
'classArgs': {
'optimize_mode': Or('maximize', 'minimize'),
Optional('R'): int,
Optional('eta'): int
'authorName': setType('authorName', str),
'experimentName': setType('experimentName', str),
Optional('description'): setType('description', str),
'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$',error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
'trainingServicePlatform': setChoice('trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller'),
Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
Optional('multiPhase'): setType('multiPhase', bool),
Optional('multiThread'): setType('multiThread', bool),
Optional('nniManagerIp'): setType('nniManagerIp', str),
Optional('logDir'): And(os.path.isdir, error=SCHEMA_PATH_ERROR % 'logDir'),
Optional('debug'): setType('debug', bool),
Optional('logLevel'): setChoice('logLevel', 'trace', 'debug', 'info', 'warning', 'error', 'fatal'),
Optional('logCollection'): setChoice('logCollection', 'http', 'none'),
'useAnnotation': setType('useAnnotation', bool),
Optional('tuner'): dict,
Optional('advisor'): dict,
Optional('assessor'): dict,
Optional('localConfig'): {
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!')
}
}
tuner_schema_dict = {
('TPE', 'Anneal', 'SMAC', 'Evolution'): {
'builtinTunerName': setChoice('builtinTunerName', 'TPE', 'Anneal', 'SMAC', 'Evolution'),
Optional('classArgs'): {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'codeDir': os.path.exists,
'classFileName': str,
'className': str,
Optional('classArgs'): dict,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'builtinAdvisorName': Or('BOHB'),
'classArgs': {
'optimize_mode': Or('maximize', 'minimize'),
Optional('min_budget'): And(int, lambda x: 0 <= x <= 9999),
Optional('max_budget'): And(int, lambda x: 0 <= x <= 9999),
Optional('eta'): And(int, lambda x: 0 <= x <= 9999),
Optional('min_points_in_model'): And(int, lambda x: 0 <= x <= 9999),
Optional('top_n_percent'): And(int, lambda x: 1 <= x <= 99),
Optional('num_samples'): And(int, lambda x: 1 <= x <= 9999),
Optional('random_fraction'): And(float, lambda x: 0.0 <= x <= 9999.0),
Optional('bandwidth_factor'): And(float, lambda x: 0.0 <= x <= 9999.0),
Optional('min_bandwidth'): And(float, lambda x: 0.0 <= x <= 9999.0)
('BatchTuner', 'GridSearch', 'Random'): {
'builtinTunerName': setChoice('builtinTunerName', 'BatchTuner', 'GridSearch', 'Random'),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'codeDir': os.path.exists,
'classFileName': str,
'className': str,
Optional('classArgs'): dict,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
}),
Optional('tuner'): Or({
'builtinTunerName': Or('TPE', 'Anneal', 'SMAC', 'Evolution'),
Optional('classArgs'): {
'optimize_mode': Or('maximize', 'minimize')
'NetworkMorphism': {
'builtinTunerName': 'NetworkMorphism',
'classArgs': {
Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('task'): setChoice('task', 'cv','nlp','common'),
Optional('input_width'): setType('input_width', int),
Optional('input_channel'): setType('input_channel', int),
Optional('n_output_node'): setType('n_output_node', int),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
Optional('includeIntermediateResults'): bool,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'builtinTunerName': Or('BatchTuner', 'GridSearch', 'Random'),
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'builtinTunerName': 'NetworkMorphism',
'classArgs': {
Optional('optimize_mode'): Or('maximize', 'minimize'),
Optional('task'): And(str, lambda x: x in ['cv','nlp','common']),
Optional('input_width'): int,
Optional('input_channel'): int,
Optional('n_output_node'): int,
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'builtinTunerName': 'MetisTuner',
'classArgs': {
Optional('optimize_mode'): Or('maximize', 'minimize'),
Optional('no_resampling'): bool,
Optional('no_candidates'): bool,
Optional('selection_num_starting_points'): int,
Optional('cold_start_num'): int,
'MetisTuner': {
'builtinTunerName': 'MetisTuner',
'classArgs': {
Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('no_resampling'): setType('no_resampling', bool),
Optional('no_candidates'): setType('no_candidates', bool),
Optional('selection_num_starting_points'): setType('selection_num_starting_points', int),
Optional('cold_start_num'): setType('cold_start_num', int),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
'customized': {
'codeDir': setPathCheck('codeDir'),
'classFileName': setType('classFileName', str),
'className': setType('className', str),
Optional('classArgs'): dict,
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
}
}
advisor_schema_dict = {
'Hyperband':{
'builtinAdvisorName': Or('Hyperband'),
'classArgs': {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('R'): setType('R', int),
Optional('eta'): setType('eta', int)
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
},{
'codeDir': os.path.exists,
'classFileName': str,
'className': str,
Optional('classArgs'): dict,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
}),
Optional('assessor'): Or({
'builtinAssessorName': lambda x: x in ['Medianstop'],
Optional('classArgs'): {
Optional('optimize_mode'): Or('maximize', 'minimize'),
Optional('start_step'): And(int, lambda x: 0 <= x <= 9999)
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999)
},{
'builtinAssessorName': lambda x: x in ['Curvefitting'],
Optional('classArgs'): {
'epoch_num': And(int, lambda x: 0 <= x <= 9999),
Optional('optimize_mode'): Or('maximize', 'minimize'),
Optional('start_step'): And(int, lambda x: 0 <= x <= 9999),
Optional('threshold'): And(float, lambda x: 0.0 <= x <= 9999.0),
Optional('gap'): And(int, lambda x: 1 <= x <= 9999)
'BOHB':{
'builtinAdvisorName': Or('BOHB'),
'classArgs': {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('min_budget'): setNumberRange('min_budget', int, 0, 9999),
Optional('max_budget'): setNumberRange('max_budget', int, 0, 9999),
Optional('eta'):setNumberRange('eta', int, 0, 9999),
Optional('min_points_in_model'): setNumberRange('min_points_in_model', int, 0, 9999),
Optional('top_n_percent'): setNumberRange('top_n_percent', int, 1, 99),
Optional('num_samples'): setNumberRange('num_samples', int, 1, 9999),
Optional('random_fraction'): setNumberRange('random_fraction', float, 0, 9999),
Optional('bandwidth_factor'): setNumberRange('bandwidth_factor', float, 0, 9999),
Optional('min_bandwidth'): setNumberRange('min_bandwidth', float, 0, 9999),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999)
},{
'codeDir': os.path.exists,
'classFileName': str,
'className': str,
Optional('classArgs'): dict,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
}),
Optional('localConfig'): {
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0))
'customized':{
'codeDir': setPathCheck('codeDir'),
'classFileName': setType('classFileName', str),
'className': setType('className', str),
Optional('classArgs'): dict,
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
}
}
assessor_schema_dict = {
'Medianstop': {
'builtinAssessorName': 'Medianstop',
Optional('classArgs'): {
Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('start_step'): setNumberRange('start_step', int, 0, 9999),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
'Curvefitting': {
'builtinAssessorName': 'Curvefitting',
Optional('classArgs'): {
'epoch_num': setNumberRange('epoch_num', int, 0, 9999),
Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('start_step'): setNumberRange('start_step', int, 0, 9999),
Optional('threshold'): setNumberRange('threshold', float, 0, 9999),
Optional('gap'): setNumberRange('gap', int, 1, 9999),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
'customized': {
'codeDir': setPathCheck('codeDir'),
'classFileName': setType('classFileName', str),
'className': setType('className', str),
Optional('classArgs'): dict,
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999)
}
}
common_trial_schema = {
'trial':{
'command': str,
'codeDir': os.path.exists,
'gpuNum': And(int, lambda x: 0 <= x <= 99999)
'command': setType('command', str),
'codeDir': setPathCheck('codeDir'),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999)
}
}
pai_trial_schema = {
'trial':{
'command': str,
'codeDir': os.path.exists,
'gpuNum': And(int, lambda x: 0 <= x <= 99999),
'cpuNum': And(int, lambda x: 0 <= x <= 99999),
'memoryMB': int,
'image': str,
Optional('shmMB'): int,
Optional('dataDir'): Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
Optional('outputDir'): Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),
Optional('virtualCluster'): str
'command': setType('command', str),
'codeDir': setPathCheck('codeDir'),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str),
Optional('shmMB'): setType('shmMB', int),
Optional('dataDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\
error='ERROR: dataDir format error, dataDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\
error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
Optional('virtualCluster'): setType('virtualCluster', str),
}
}
pai_config_schema = {
'paiConfig':{
'userName': str,
'passWord': str,
'host': str
}
'paiConfig':{
'userName': setType('userName', str),
'passWord': setType('passWord', str),
'host': setType('host', str)
}
}
kubeflow_trial_schema = {
'trial':{
'codeDir': os.path.exists,
'codeDir': setPathCheck('codeDir'),
Optional('ps'): {
'replicas': int,
'command': str,
'gpuNum': And(int, lambda x: 0 <= x <= 99999),
'cpuNum': And(int, lambda x: 0 <= x <= 99999),
'memoryMB': int,
'image': str
'replicas': setType('replicas', int),
'command': setType('command', str),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str)
},
Optional('master'): {
'replicas': int,
'command': str,
'gpuNum': And(int, lambda x: 0 <= x <= 99999),
'cpuNum': And(int, lambda x: 0 <= x <= 99999),
'memoryMB': int,
'image': str
'replicas': setType('replicas', int),
'command': setType('command', str),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str)
},
Optional('worker'):{
'replicas': int,
'command': str,
'gpuNum': And(int, lambda x: 0 <= x <= 99999),
'cpuNum': And(int, lambda x: 0 <= x <= 99999),
'memoryMB': int,
'image': str
'replicas': setType('replicas', int),
'command': setType('command', str),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str)
}
}
}
kubeflow_config_schema = {
'kubeflowConfig':Or({
'operator': Or('tf-operator', 'pytorch-operator'),
'apiVersion': str,
Optional('storage'): Or('nfs', 'azureStorage'),
'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
'apiVersion': setType('apiVersion', str),
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
'nfs': {
'server': str,
'path': str
'server': setType('server', str),
'path': setType('path', str)
}
},{
'operator': Or('tf-operator', 'pytorch-operator'),
'apiVersion': str,
Optional('storage'): Or('nfs', 'azureStorage'),
'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
'apiVersion': setType('apiVersion', str),
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
'keyVault': {
'vaultName': Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
'name': Regex('([0-9]|[a-z]|[A-Z]|-){1,127}')
'vaultName': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\
error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
'name': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\
error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
},
'azureStorage': {
'accountName': Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),
'azureShare': Regex('([0-9]|[a-z]|[A-Z]|-){3,63}')
'accountName': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),\
error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
'azureShare': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,63}'),\
error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
}
})
}
frameworkcontroller_trial_schema = {
'trial':{
'codeDir': os.path.exists,
'codeDir': setPathCheck('codeDir'),
'taskRoles': [{
'name': str,
'taskNum': int,
'name': setType('name', str),
'taskNum': setType('taskNum', int),
'frameworkAttemptCompletionPolicy': {
'minFailedTaskCount': int,
'minSucceededTaskCount': int
'minFailedTaskCount': setType('minFailedTaskCount', int),
'minSucceededTaskCount': setType('minSucceededTaskCount', int),
},
'command': str,
'gpuNum': And(int, lambda x: 0 <= x <= 99999),
'cpuNum': And(int, lambda x: 0 <= x <= 99999),
'memoryMB': int,
'image': str
'command': setType('command', str),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str)
}]
}
}
frameworkcontroller_config_schema = {
'frameworkcontrollerConfig':Or({
Optional('storage'): Or('nfs', 'azureStorage'),
Optional('serviceAccountName'): str,
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
Optional('serviceAccountName'): setType('serviceAccountName', str),
'nfs': {
'server': str,
'path': str
'server': setType('server', str),
'path': setType('path', str)
}
},{
Optional('storage'): Or('nfs', 'azureStorage'),
Optional('serviceAccountName'): str,
Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
Optional('serviceAccountName'): setType('serviceAccountName', str),
'keyVault': {
'vaultName': Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),
'name': Regex('([0-9]|[a-z]|[A-Z]|-){1,127}')
'vaultName': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\
error='ERROR: vaultName format error, vaultName support using (0-9|a-z|A-Z|-)'),
'name': And(Regex('([0-9]|[a-z]|[A-Z]|-){1,127}'),\
error='ERROR: name format error, name support using (0-9|a-z|A-Z|-)')
},
'azureStorage': {
'accountName': Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),
'azureShare': Regex('([0-9]|[a-z]|[A-Z]|-){3,63}')
'accountName': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,31}'),\
error='ERROR: accountName format error, accountName support using (0-9|a-z|A-Z|-)'),
'azureShare': And(Regex('([0-9]|[a-z]|[A-Z]|-){3,63}'),\
error='ERROR: azureShare format error, azureShare support using (0-9|a-z|A-Z|-)')
}
})
}
machine_list_schima = {
Optional('machineList'):[Or({
'ip': str,
Optional('port'): And(int, lambda x: 0 < x < 65535),
'username': str,
'passwd': str,
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0))
'ip': setType('ip', str),
Optional('port'): setNumberRange('port', int, 1, 65535),
'username': setType('username', str),
'passwd': setType('passwd', str),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!')
},{
'ip': str,
Optional('port'): And(int, lambda x: 0 < x < 65535),
'username': str,
'sshKeyPath': os.path.exists,
Optional('passphrase'): str,
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0))
'ip': setType('ip', str),
Optional('port'): setNumberRange('port', int, 1, 65535),
'username': setType('username', str),
'sshKeyPath': setPathCheck('sshKeyPath'),
Optional('passphrase'): setType('passphrase', str),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!')
})]
}
......
......@@ -19,8 +19,9 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import os
from colorama import Fore
NNICTL_HOME_DIR = os.path.join(os.environ['HOME'], '.local', 'nnictl')
NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), '.local', 'nnictl')
ERROR_INFO = 'ERROR: %s'
......@@ -32,7 +33,7 @@ DEFAULT_REST_PORT = 8080
REST_TIME_OUT = 20
EXPERIMENT_SUCCESS_INFO = '\033[1;32;32mSuccessfully started experiment!\n\033[0m' \
EXPERIMENT_SUCCESS_INFO = Fore.GREEN + 'Successfully started experiment!\n' + Fore.RESET + \
'-----------------------------------------------------------------------\n' \
'The experiment id is %s\n'\
'The Web UI urls are: %s\n' \
......@@ -80,8 +81,28 @@ PACKAGE_REQUIREMENTS = {
'BOHB': 'bohb_advisor'
}
COLOR_RED_FORMAT = '\033[1;31;31m%s\033[0m'
TUNERS_SUPPORTING_IMPORT_DATA = {
'TPE',
'Anneal',
'GridSearch',
'MetisTuner',
'BOHB'
}
TUNERS_NO_NEED_TO_IMPORT_DATA = {
'Random',
'Batch_tuner',
'Hyperband'
}
COLOR_RED_FORMAT = Fore.RED + '%s'
COLOR_GREEN_FORMAT = Fore.GREEN + '%s'
COLOR_YELLOW_FORMAT = Fore.YELLOW + '%s'
SCHEMA_TYPE_ERROR = '%s should be %s type!'
COLOR_GREEN_FORMAT = '\033[1;32;32m%s\033[0m'
SCHEMA_RANGE_ERROR = '%s should be in range of %s!'
COLOR_YELLOW_FORMAT = '\033[1;33;33m%s\033[0m'
SCHEMA_PATH_ERROR = '%s path not exist!'
......@@ -32,12 +32,13 @@ from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, detect_process, detect_port
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, detect_process, detect_port, get_user, get_python_dir
from .constants import *
import random
import site
import time
from pathlib import Path
from .command_utils import check_output_command, kill_command
def get_log_path(config_file_name):
'''generate stdout and stderr log path'''
......@@ -49,14 +50,10 @@ def print_log_content(config_file_name):
'''print log information'''
stdout_full_path, stderr_full_path = get_log_path(config_file_name)
print_normal(' Stdout:')
stdout_cmds = ['cat', stdout_full_path]
stdout_content = check_output(stdout_cmds)
print(stdout_content.decode('utf-8'))
print(check_output_command(stdout_full_path))
print('\n\n')
print_normal(' Stderr:')
stderr_cmds = ['cat', stderr_full_path]
stderr_content = check_output(stderr_cmds)
print(stderr_content.decode('utf-8'))
print(check_output_command(stderr_full_path))
def get_nni_installation_path():
''' Find nni lib from the following locations in order
......@@ -67,7 +64,7 @@ def get_nni_installation_path():
Return None if nothing is found
'''
def _generate_installation_path(sitepackages_path):
python_dir = str(Path(sitepackages_path).parents[2])
python_dir = get_python_dir(sitepackages_path)
entry_file = os.path.join(python_dir, 'nni', 'main.js')
if os.path.isfile(entry_file):
return python_dir
......@@ -132,7 +129,11 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
log_header = LOG_HEADER % str(time_now)
stdout_file.write(log_header)
stderr_file.write(log_header)
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
if sys.platform == 'win32':
from subprocess import CREATE_NEW_PROCESS_GROUP
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file, creationflags=CREATE_NEW_PROCESS_GROUP)
else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
return process, str(time_now)
def set_trial_config(experiment_config, port, config_file_name):
......@@ -325,12 +326,12 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT)
response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
if check_response(response):
return response
else:
_, stderr_full_path = get_log_path(config_file_name)
if response:
if response is not None:
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
print_error('Setting experiment error, error message is {}'.format(response.text))
......@@ -357,7 +358,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation
if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), os.environ['USER'], 'nni', 'annotation')
path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
if not os.path.isdir(path):
os.makedirs(path)
path = tempfile.mkdtemp(dir=path)
......@@ -380,8 +381,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
print_error('Restful server start failed!')
print_log_content(config_file_name)
try:
cmds = ['kill', str(rest_process.pid)]
call(cmds)
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1)
......@@ -395,8 +395,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
else:
print_error('Failed! Error is: {}'.format(err_msg))
try:
cmds = ['kill', str(rest_process.pid)]
call(cmds)
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1)
......@@ -409,8 +408,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
else:
print_error('Set local config failed!')
try:
cmds = ['kill', str(rest_process.pid)]
call(cmds)
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1)
......@@ -425,8 +423,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if err_msg:
print_error('Failed! Error is: {}'.format(err_msg))
try:
cmds = ['kill', str(rest_process.pid)]
call(cmds)
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1)
......@@ -441,8 +438,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if err_msg:
print_error('Failed! Error is: {}'.format(err_msg))
try:
cmds = ['pkill', str(rest_process.pid)]
call(cmds)
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1)
......@@ -457,8 +453,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if err_msg:
print_error('Failed! Error is: {}'.format(err_msg))
try:
cmds = ['pkill', str(rest_process.pid)]
call(cmds)
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1)
......@@ -477,8 +472,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
print_error('Start experiment failed!')
print_log_content(config_file_name)
try:
cmds = ['kill', str(rest_process.pid)]
call(cmds)
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1)
......
......@@ -20,8 +20,11 @@
import os
import json
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA, FRAMEWORKCONTROLLER_CONFIG_SCHEMA
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA, FRAMEWORKCONTROLLER_CONFIG_SCHEMA, \
tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
from schema import SchemaMissingKeyError, SchemaForbiddenKeyError, SchemaUnexpectedTypeError, SchemaWrongKeyError, SchemaError
from .common_utils import get_json_content, print_error, print_warning, print_normal
from schema import Schema, And, Use, Optional, Regex, Or
def expand_path(experiment_config, key):
'''Change '~' to user home directory'''
......@@ -134,21 +137,49 @@ def validate_common_content(experiment_config):
'kubeflow': KUBEFLOW_CONFIG_SCHEMA,
'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA
}
separate_schema_dict = {
'tuner': tuner_schema_dict,
'advisor': advisor_schema_dict,
'assessor': assessor_schema_dict
}
separate_builtInName_dict = {
'tuner': 'builtinTunerName',
'advisor': 'builtinAdvisorName',
'assessor': 'builtinAssessorName'
}
try:
schema_dict.get(experiment_config['trainingServicePlatform']).validate(experiment_config)
#set default value
if experiment_config.get('maxExecDuration') is None:
experiment_config['maxExecDuration'] = '999d'
if experiment_config.get('maxTrialNum') is None:
experiment_config['maxTrialNum'] = 99999
if experiment_config['trainingServicePlatform'] == 'remote':
for index in range(len(experiment_config['machineList'])):
if experiment_config['machineList'][index].get('port') is None:
experiment_config['machineList'][index]['port'] = 22
except Exception as exception:
print_error('Your config file is not correct, please check your config file content!\n%s' % exception)
for separate_key in separate_schema_dict.keys():
if experiment_config.get(separate_key):
if experiment_config[separate_key].get(separate_builtInName_dict[separate_key]):
validate = False
for key in separate_schema_dict[separate_key].keys():
if key.__contains__(experiment_config[separate_key][separate_builtInName_dict[separate_key]]):
Schema({**separate_schema_dict[separate_key][key]}).validate(experiment_config[separate_key])
validate = True
break
if not validate:
print_error('%s %s error!' % (separate_key, separate_builtInName_dict[separate_key]))
exit(1)
else:
Schema({**separate_schema_dict[separate_key]['customized']}).validate(experiment_config[separate_key])
except SchemaError as error:
print_error('Your config file is not correct, please check your config file content!')
if error.__str__().__contains__('Wrong key'):
print_error(' '.join(error.__str__().split()[:3]))
else:
print_error(error)
exit(1)
#set default value
if experiment_config.get('maxExecDuration') is None:
experiment_config['maxExecDuration'] = '999d'
if experiment_config.get('maxTrialNum') is None:
experiment_config['maxTrialNum'] = 99999
if experiment_config['trainingServicePlatform'] == 'remote':
for index in range(len(experiment_config['machineList'])):
if experiment_config['machineList'][index].get('port') is None:
experiment_config['machineList'][index]['port'] = 22
def validate_customized_file(experiment_config, spec_key):
'''
......@@ -230,6 +261,9 @@ def validate_all_content(experiment_config, config_path):
validate_pai_trial_conifg(experiment_config)
experiment_config['maxExecDuration'] = parse_time(experiment_config['maxExecDuration'])
if experiment_config.get('advisor'):
if experiment_config.get('assessor') or experiment_config.get('tuner'):
print_error('advisor could not be set with assessor or tuner simultaneously!')
exit(1)
parse_advisor_content(experiment_config)
validate_annotation_content(experiment_config, 'advisor', 'builtinAdvisorName')
else:
......
......@@ -22,11 +22,13 @@
import argparse
import pkg_resources
from .launcher import create_experiment, resume_experiment
from .updater import update_searchspace, update_concurrency, update_duration, update_trialnum
from .updater import update_searchspace, update_concurrency, update_duration, update_trialnum, import_data
from .nnictl_utils import *
from .package_management import *
from .constants import *
from .tensorboard_utils import *
from colorama import init
init(autoreset=True)
if os.environ.get('COVERAGE_PROCESS_START'):
import coverage
......@@ -98,13 +100,9 @@ def parse_args():
parser_trial_ls.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_ls.set_defaults(func=trial_ls)
parser_trial_kill = parser_trial_subparsers.add_parser('kill', help='kill trial jobs')
parser_trial_kill.add_argument('id', nargs='?', help='id of the trial to be killed')
parser_trial_kill.add_argument('--experiment', '-E', required=True, dest='experiment', help='experiment id of the trial')
parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_kill.add_argument('--trial_id', '-T', required=True, dest='trial_id', help='the id of trial to be killed')
parser_trial_kill.set_defaults(func=trial_kill)
parser_trial_export = parser_trial_subparsers.add_parser('export', help='export trial job results to csv')
parser_trial_export.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_export.add_argument('--file', '-f', required=True, dest='csv_path', help='target csv file path')
parser_trial_export.set_defaults(func=export_trials_data)
#parse experiment command
parser_experiment = subparsers.add_parser('experiment', help='get experiment information')
......@@ -119,6 +117,17 @@ def parse_args():
parser_experiment_list = parser_experiment_subparsers.add_parser('list', help='list all of running experiment ids')
parser_experiment_list.add_argument('all', nargs='?', help='list all of experiments')
parser_experiment_list.set_defaults(func=experiment_list)
#import tuning data
parser_import_data = parser_experiment_subparsers.add_parser('import', help='import additional data')
parser_import_data.add_argument('id', nargs='?', help='the id of experiment')
parser_import_data.add_argument('--filename', '-f', required=True)
parser_import_data.set_defaults(func=import_data)
#export trial data
parser_trial_export = parser_experiment_subparsers.add_parser('export', help='export trial job results to csv or json')
parser_trial_export.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_export.add_argument('--type', '-t', choices=['json', 'csv'], required=True, dest='type', help='target file type')
parser_trial_export.add_argument('--filename', '-f', required=True, dest='path', help='target file path')
parser_trial_export.set_defaults(func=export_trials_data)
#TODO:finish webui function
#parse board command
......@@ -153,8 +162,8 @@ def parse_args():
parser_log_stderr.add_argument('--path', action='store_true', default=False, help='get the path of stderr file')
parser_log_stderr.set_defaults(func=log_stderr)
parser_log_trial = parser_log_subparsers.add_parser('trial', help='get trial log path')
parser_log_trial.add_argument('id', nargs='?', help='id of the trial to be found the log path')
parser_log_trial.add_argument('--experiment', '-E', dest='experiment', help='experiment id of the trial, xperiment ID of the trial, required when id is not empty.')
parser_log_trial.add_argument('id', nargs='?', help='the id of experiment')
parser_log_trial.add_argument('--trial_id', '-T', dest='trial_id', help='find trial log path by id')
parser_log_trial.set_defaults(func=log_trial)
#parse package command
......@@ -172,7 +181,7 @@ def parse_args():
parser_tensorboard_subparsers = parser_tensorboard.add_subparsers()
parser_tensorboard_start = parser_tensorboard_subparsers.add_parser('start', help='start tensorboard')
parser_tensorboard_start.add_argument('id', nargs='?', help='the id of experiment')
parser_tensorboard_start.add_argument('--trialid', dest='trialid', help='the id of trial')
parser_tensorboard_start.add_argument('--trial_id', '-T', dest='trial_id', help='the id of trial')
parser_tensorboard_start.add_argument('--port', dest='port', default=6006, help='the port to start tensorboard')
parser_tensorboard_start.set_defaults(func=start_tensorboard)
parser_tensorboard_start = parser_tensorboard_subparsers.add_parser('stop', help='stop tensorboard')
......
......@@ -24,7 +24,6 @@ import psutil
import json
import datetime
import time
from subprocess import call, check_output
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .config_utils import Config, Experiments
......@@ -32,6 +31,7 @@ from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url
from .constants import NNICTL_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
from .common_utils import print_normal, print_error, print_warning, detect_process
from .command_utils import check_output_command, kill_command
def get_experiment_time(port):
'''get the startTime and endTime of an experiment'''
......@@ -103,14 +103,11 @@ def check_experiment_id(args):
return None
else:
return running_experiment_list[0]
if hasattr(args, "experiment"):
if experiment_dict.get(args.experiment):
return args.experiment
elif hasattr(args, "id"):
if experiment_dict.get(args.id):
return args.id
print_error('Id not correct!')
return None
if experiment_dict.get(args.id):
return args.id
else:
print_error('Id not correct!')
return None
def parse_ids(args):
'''Parse the arguments for nnictl stop
......@@ -222,14 +219,12 @@ def stop_experiment(args):
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
if rest_pid:
stop_rest_cmds = ['kill', str(rest_pid)]
call(stop_rest_cmds)
kill_command(rest_pid)
tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list:
try:
cmds = ['kill', '-9', str(tensorboard_pid)]
call(cmds)
kill_command(tensorboard_pid)
except Exception as exception:
print_error(exception)
nni_config.set_config('tensorboardPidList', [])
......@@ -306,14 +301,6 @@ def experiment_status(args):
else:
print(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
def get_log_content(file_name, cmds):
'''use cmds to read config content'''
if os.path.exists(file_name):
rest = check_output(cmds)
print(rest.decode('utf-8'))
else:
print_normal('NULL!')
def log_internal(args, filetype):
'''internal function to call get_log_content'''
file_name = get_config_filename(args)
......@@ -321,15 +308,8 @@ def log_internal(args, filetype):
file_full_path = os.path.join(NNICTL_HOME_DIR, file_name, 'stdout')
else:
file_full_path = os.path.join(NNICTL_HOME_DIR, file_name, 'stderr')
if args.head:
get_log_content(file_full_path, ['head', '-' + str(args.head), file_full_path])
elif args.tail:
get_log_content(file_full_path, ['tail', '-' + str(args.tail), file_full_path])
elif args.path:
print_normal('The path of stdout file is: ' + file_full_path)
else:
get_log_content(file_full_path, ['cat', file_full_path])
print(check_output_command(file_full_path, head=args.head, tail=args.tail))
def log_stdout(args):
'''get stdout log'''
log_internal(args, 'stdout')
......@@ -357,16 +337,15 @@ def log_trial(args):
else:
print_error('Restful server is not running...')
exit(1)
if args.experiment:
if args.id:
if trial_id_path_dict.get(args.id):
print('id:' + args.id + ' path:' + trial_id_path_dict[args.id])
if args.id:
if args.trial_id:
if trial_id_path_dict.get(args.trial_id):
print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
else:
print_error('trial id is not valid!')
exit(1)
else:
print_error('please specific the trial id!')
print_error("trial id list in this experiment: " + str(list(trial_id_path_dict.keys())))
exit(1)
else:
for key in trial_id_path_dict:
......@@ -509,10 +488,19 @@ def export_trials_data(args):
# dframe = pd.DataFrame.from_records([parse_trial_data(t_data) for t_data in content])
# dframe.to_csv(args.csv_path, sep='\t')
records = parse_trial_data(content)
with open(args.csv_path, 'w') as f_csv:
writer = csv.DictWriter(f_csv, set.union(*[set(r.keys()) for r in records]))
writer.writeheader()
writer.writerows(records)
if args.type == 'json':
json_records = []
for trial in records:
value = trial.pop('reward', None)
trial_id = trial.pop('id', None)
json_records.append({'parameter': trial, 'value': value, 'id': trial_id})
with open(args.path, 'w') as file:
if args.type == 'csv':
writer = csv.DictWriter(file, set.union(*[set(r.keys()) for r in records]))
writer.writeheader()
writer.writerows(records)
else:
json.dump(json_records, file)
else:
print_error('Export failed...')
else:
......
......@@ -20,17 +20,18 @@
import nni
import os
import sys
from subprocess import call
from .constants import PACKAGE_REQUIREMENTS
from .common_utils import print_normal, print_error
from .command_utils import install_requirements_command
def process_install(package_name):
if PACKAGE_REQUIREMENTS.get(package_name) is None:
print_error('{0} is not supported!' % package_name)
else:
requirements_path = os.path.join(nni.__path__[0], PACKAGE_REQUIREMENTS[package_name])
cmds = 'cd ' + requirements_path + ' && python3 -m pip install --user -r requirements.txt'
call(cmds, shell=True)
install_requirements_command(requirements_path)
def package_install(args):
'''install packages'''
......@@ -39,4 +40,4 @@ def package_install(args):
def package_show(args):
'''show all packages'''
print(' '.join(PACKAGE_REQUIREMENTS.keys()))
\ No newline at end of file
......@@ -23,39 +23,48 @@ import time
import requests
from .url_utils import check_status_url
from .constants import REST_TIME_OUT
from .common_utils import print_error
def rest_put(url, data, timeout):
def rest_put(url, data, timeout, show_error=False):
'''Call rest put method'''
try:
response = requests.put(url, headers={'Accept': 'application/json', 'Content-Type': 'application/json'},\
data=data, timeout=timeout)
return response
except Exception:
except Exception as exception:
if show_error:
print_error(exception)
return None
def rest_post(url, data, timeout):
def rest_post(url, data, timeout, show_error=False):
'''Call rest post method'''
try:
response = requests.post(url, headers={'Accept': 'application/json', 'Content-Type': 'application/json'},\
data=data, timeout=timeout)
return response
except Exception:
except Exception as exception:
if show_error:
print_error(exception)
return None
def rest_get(url, timeout):
def rest_get(url, timeout, show_error=False):
'''Call rest get method'''
try:
response = requests.get(url, timeout=timeout)
return response
except Exception:
except Exception as exception:
if show_error:
print_error(exception)
return None
def rest_delete(url, timeout):
def rest_delete(url, timeout, show_error=False):
'''Call rest delete method'''
try:
response = requests.delete(url, timeout=timeout)
return response
except Exception:
except Exception as exception:
if show_error:
print_error(exception)
return None
def check_rest_server(rest_port):
......
......@@ -21,14 +21,14 @@
import os
from .common_utils import print_error
from subprocess import call
from .command_utils import install_package_command
def check_environment():
'''check if paramiko is installed'''
try:
import paramiko
except:
cmds = 'python3 -m pip install --user paramiko'
call(cmds, shell=True)
install_package_command('paramiko')
def copy_remote_directory_to_local(sftp, remote_path, local_path):
'''copy remote directory to local machine'''
......@@ -56,4 +56,4 @@ def create_ssh_sftp_client(host_ip, port, username, password):
sftp = paramiko.SFTPClient.from_transport(conn)
return sftp
except Exception as exception:
print_error('Create ssh client error %s\n' % exception)
\ No newline at end of file
print_error('Create ssh client error %s\n' % exception)
......@@ -40,7 +40,7 @@ def parse_log_path(args, trial_content):
path_list = []
host_list = []
for trial in trial_content:
if args.trialid and args.trialid != 'all' and trial.get('id') != args.trialid:
if args.trial_id and args.trial_id != 'all' and trial.get('id') != args.trial_id:
continue
pattern = r'(?P<head>.+)://(?P<host>.+):(?P<path>.*)'
match = re.search(pattern,trial['logPath'])
......@@ -48,7 +48,7 @@ def parse_log_path(args, trial_content):
path_list.append(match.group('path'))
host_list.append(match.group('host'))
if not path_list:
print_error('Trial id %s error!' % args.trialid)
print_error('Trial id %s error!' % args.trial_id)
exit(1)
return path_list, host_list
......@@ -154,7 +154,7 @@ def start_tensorboard(args):
if not trial_content:
print_error('No trial information!')
exit(1)
if len(trial_content) > 1 and not args.trialid:
if len(trial_content) > 1 and not args.trial_id:
print_error('There are multiple trials, please set trial id!')
exit(1)
experiment_id = nni_config.get_config('experimentId')
......
......@@ -21,13 +21,13 @@
import json
import os
from .rest_utils import rest_put, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url
from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url, import_data_url
from .config_utils import Config
from .common_utils import get_json_content
from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import check_experiment_id, get_experiment_port, get_config_filename
from .launcher_utils import parse_time
from .constants import REST_TIME_OUT
from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA
def validate_digit(value, start, end):
'''validate if a digit is valid'''
......@@ -39,6 +39,23 @@ def validate_file(path):
if not os.path.exists(path):
raise FileNotFoundError('%s is not a valid file path' % path)
def validate_dispatcher(args):
'''validate if the dispatcher of the experiment supports importing data'''
nni_config = Config(get_config_filename(args)).get_config('experimentConfig')
if nni_config.get('tuner') and nni_config['tuner'].get('builtinTunerName'):
dispatcher_name = nni_config['tuner']['builtinTunerName']
elif nni_config.get('advisor') and nni_config['advisor'].get('builtinAdvisorName'):
dispatcher_name = nni_config['advisor']['builtinAdvisorName']
else: # otherwise it should be a customized one
return
if dispatcher_name not in TUNERS_SUPPORTING_IMPORT_DATA:
if dispatcher_name in TUNERS_NO_NEED_TO_IMPORT_DATA:
print_warning("There is no need to import data for %s" % dispatcher_name)
exit(0)
else:
print_error("%s does not support importing addtional data" % dispatcher_name)
exit(1)
def load_search_space(path):
'''load search space content'''
content = json.dumps(get_json_content(path))
......@@ -71,7 +88,7 @@ def update_experiment_profile(args, key, value):
if response and check_response(response):
return response
else:
print('ERROR: restful server is not running...')
print_error('Restful server is not running...')
return None
def update_searchspace(args):
......@@ -80,18 +97,19 @@ def update_searchspace(args):
args.port = get_experiment_port(args)
if args.port is not None:
if update_experiment_profile(args, 'searchSpace', content):
print('INFO: update %s success!' % 'searchSpace')
print_normal('Update %s success!' % 'searchSpace')
else:
print('ERROR: update %s failed!' % 'searchSpace')
print_error('Update %s failed!' % 'searchSpace')
def update_concurrency(args):
validate_digit(args.value, 1, 1000)
args.port = get_experiment_port(args)
if args.port is not None:
if update_experiment_profile(args, 'trialConcurrency', int(args.value)):
print('INFO: update %s success!' % 'concurrency')
print_normal('Update %s success!' % 'concurrency')
else:
print('ERROR: update %s failed!' % 'concurrency')
print_error('Update %s failed!' % 'concurrency')
def update_duration(args):
#parse time, change time unit to seconds
......@@ -99,13 +117,38 @@ def update_duration(args):
args.port = get_experiment_port(args)
if args.port is not None:
if update_experiment_profile(args, 'maxExecDuration', int(args.value)):
print('INFO: update %s success!' % 'duration')
print_normal('Update %s success!' % 'duration')
else:
print('ERROR: update %s failed!' % 'duration')
print_error('Update %s failed!' % 'duration')
def update_trialnum(args):
validate_digit(args.value, 1, 999999999)
if update_experiment_profile(args, 'maxTrialNum', int(args.value)):
print('INFO: update %s success!' % 'trialnum')
print_normal('Update %s success!' % 'trialnum')
else:
print('ERROR: update %s failed!' % 'trialnum')
\ No newline at end of file
print_error('Update %s failed!' % 'trialnum')
def import_data(args):
'''import additional data to the experiment'''
validate_file(args.filename)
validate_dispatcher(args)
content = load_search_space(args.filename)
args.port = get_experiment_port(args)
if args.port is not None:
if import_data_to_restful_server(args, content):
print_normal('Import data success!')
else:
print_error('Import data failed!')
def import_data_to_restful_server(args, content):
'''call restful server to import data to the experiment'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT)
if response and check_response(response):
return response
else:
print_error('Restful server is not running...')
return None
......@@ -29,6 +29,8 @@ EXPERIMENT_API = '/experiment'
CLUSTER_METADATA_API = '/experiment/cluster-metadata'
IMPORT_DATA_API = '/experiment/import-data'
CHECK_STATUS_API = '/check-status'
TRIAL_JOBS_API = '/trial-jobs'
......@@ -46,6 +48,11 @@ def cluster_metadata_url(port):
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CLUSTER_METADATA_API)
def import_data_url(port):
'''get import_data_url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, IMPORT_DATA_API)
def experiment_url(port):
'''get experiment_url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPERIMENT_API)
......
......@@ -25,6 +25,9 @@ import time
from xml.dom import minidom
def check_ready_to_run():
#TODO check process in windows
if sys.platform == 'win32':
return True
pgrep_output =subprocess.check_output('pgrep -fx \'python3 -m nni_gpu_tool.gpu_metrics_collector\'', shell=True)
pidList = []
for pid in pgrep_output.splitlines():
......
......@@ -8,11 +8,12 @@ setuptools.setup(
python_requires = '>=3.5',
install_requires = [
'requests',
'pyyaml',
'ruamel.yaml',
'psutil',
'astor',
'schema',
'PythonWebHDFS'
'PythonWebHDFS',
'colorama'
],
author = 'Microsoft NNI Team',
......
$NNI_DEPENDENCY_FOLDER = "C:\tmp\$env:USERNAME"
$env:PYTHONIOENCODING = "UTF-8"
if($env:VIRTUAL_ENV){
$NNI_PYTHON3 = $env:VIRTUAL_ENV + "\Scripts"
$NNI_PKG_FOLDER = $env:VIRTUAL_ENV + "\nni"
Remove-Item "$NNI_PYTHON3\node.exe" -Force
}
else{
$NNI_PYTHON3 = $(python -c 'import site; from pathlib import Path; print(Path(site.getsitepackages()[0]))')
$NNI_PKG_FOLDER = $NNI_PYTHON3 + "\nni"
Remove-Item "$NNI_PYTHON3\Scripts\node.exe" -Force
}
$PIP_UNINSTALL = """$NNI_PYTHON3\python"" -m pip uninstall -y "
$NNI_NODE_FOLDER = $NNI_DEPENDENCY_FOLDER+"\nni-node"
$NNI_YARN_FOLDER = $NNI_DEPENDENCY_FOLDER+"\nni-yarn"
# uninstall
Remove-Item $NNI_PKG_FOLDER -Recurse -Force
cmd /C $PIP_UNINSTALL "nni"
# clean
Remove-Item "src/nni_manager/dist" -Recurse -Force
Remove-Item "src/nni_manager/node_modules" -Recurse -Force
Remove-Item "src/webui/build" -Recurse -Force
Remove-Item "src/webui/node_modules" -Recurse -Force
Remove-Item $NNI_YARN_FOLDER -Recurse -Force
Remove-Item $NNI_NODE_FOLDER -Recurse -Force
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment