"docs/en_US/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "a6467ad88d1090543b842804d28e7f162b1f1c02"
Commit 252f36f8 authored by Deshui Yu's avatar Deshui Yu
Browse files

NNI dogfood version 1

parent 781cea26
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import json
import yaml
import psutil
from .constants import ERROR_INFO, NORMAL_INFO
def get_yml_content(file_path):
'''Load yaml file content'''
try:
with open(file_path, 'r') as file:
return yaml.load(file)
except TypeError as err:
print('Error: ', err)
return None
def get_json_content(file_path):
'''Load json file content'''
try:
with open(file_path, 'r') as file:
return json.load(file)
except TypeError as err:
print('Error: ', err)
return None
def print_error(content):
'''Print error information to screen'''
print(ERROR_INFO % content)
def print_normal(content):
'''Print error information to screen'''
print(NORMAL_INFO % content)
def detect_process(pid):
'''Detect if a process is alive'''
try:
process = psutil.Process(pid)
return process.is_running()
except:
return False
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import os
import json
import shutil
from .constants import METADATA_DIR, METADATA_FULL_PATH
class Config:
'''a util class to load and save config'''
def __init__(self):
os.makedirs(METADATA_DIR, exist_ok=True)
self.config_file = METADATA_FULL_PATH
self.config = self.read_file()
def get_all_config(self):
'''get all of config values'''
return json.dumps(self.config)
def set_config(self, key, value):
'''set {key:value} paris to self.config'''
self.config = self.read_file()
self.config[key] = value
self.write_file()
def get_config(self, key):
'''get a value according to key'''
return self.config.get(key)
def copy_metadata_to_new_path(self, path):
'''copy metadata to a new path'''
if not os.path.exists(path):
os.mkdir(path)
shutil.copy(self.config_file, path)
def write_file(self):
'''save config to local file'''
if self.config:
try:
with open(self.config_file, 'w') as file:
json.dump(self.config, file)
except IOError as error:
print('Error:', error)
return
def read_file(self):
'''load config from local file'''
if os.path.exists(self.config_file):
try:
with open(self.config_file, 'r') as file:
return json.load(file)
except ValueError:
return {}
return {}
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import os
REST_PORT = 51188
HOME_DIR = os.path.join(os.environ['HOME'], 'nni')
METADATA_DIR = os.path.join(HOME_DIR, 'nnictl')
METADATA_FULL_PATH = os.path.join(METADATA_DIR, 'metadata')
LOG_DIR = os.path.join(HOME_DIR, 'nnictl', 'log')
STDOUT_FULL_PATH = os.path.join(LOG_DIR, 'stdout')
STDERR_FULL_PATH = os.path.join(LOG_DIR, 'stderr')
ERROR_INFO = 'Error: %s'
NORMAL_INFO = 'Info: %s'
WARNING_INFO = 'Waining: %s'
EXPERIMENT_SUCCESS_INFO = 'Start experiment success! The experiment id is %s, and the restful server post is %s.\n' \
'You can use these commands to get more information about this experiment:\n' \
' commands description\n' \
'1. nnictl experiment ls list all of experiments\n' \
'2. nnictl trial ls list all of trial jobs\n' \
'3. nnictl stop stop a experiment\n' \
'4. nnictl trial kill kill a trial job by id\n' \
'5. nnictl --help get help information about nnictl\n' \
'6. nnictl webui url get the url of web ui'
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import json
import os
import shutil
from subprocess import Popen, PIPE
import tempfile
from annotation import *
from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick
from .url_utils import cluster_metadata_url, experiment_url
from .config_utils import Config
from .common_utils import get_yml_content, get_json_content, print_error, print_normal
from .constants import EXPERIMENT_SUCCESS_INFO, STDOUT_FULL_PATH, STDERR_FULL_PATH, LOG_DIR, REST_PORT, ERROR_INFO, NORMAL_INFO
from .webui_utils import start_web_ui, check_web_ui
def start_rest_server(manager, port, platform, mode, experiment_id=None):
'''Run nni manager process'''
print_normal('Checking experiment...')
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
if rest_port and check_rest_server_quick(rest_port):
print_error('There is an experiment running, please stop it first...')
print_normal('You can use \'nnictl stop\' command to stop an experiment!')
exit(0)
print_normal('Starting restful server...')
cmds = [manager, '--port', str(port), '--mode', platform, '--start_mode', mode]
if mode == 'resume':
cmds += ['--experiment_id', experiment_id]
if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR)
stdout_file = open(STDOUT_FULL_PATH, 'a+')
stderr_file = open(STDERR_FULL_PATH, 'a+')
process = Popen(cmds, stdout=stdout_file, stderr=stderr_file)
return process
def set_local_config(experiment_config, port):
'''Call setClusterMetadata (rest PUT /parameters/cluster-metadata) to pass platform and machineList"'''
request_data = dict()
request_data['codeDir'] = experiment_config['trial']['trialCodeDir']
request_data['command'] = experiment_config['trial']['trialCommand']
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20)
return True if response and response.status_code == 200 else False
def set_remote_config(experiment_config, port):
'''Call setClusterMetadata to pass trial'''
#set machine_list
request_data = dict()
request_data['machine_list'] = experiment_config['machineList']
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20)
if not response or not response.status_code == 200:
return False
#set trial_config
request_data = dict()
value_dict = dict()
value_dict['command'] = experiment_config['trial']['trialCommand']
value_dict['codeDir'] = experiment_config['trial']['trialCodeDir']
value_dict['gpuNum'] = experiment_config['trial']['trialGpuNum']
request_data['trial_config'] = value_dict
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20)
return True if response.status_code == 200 else False
def set_experiment(experiment_config, mode, port):
'''Call startExperiment (rest POST /experiment) with yaml file content'''
request_data = dict()
request_data['authorName'] = experiment_config['authorName']
request_data['experimentName'] = experiment_config['experimentName']
request_data['trialConcurrency'] = experiment_config['trialConcurrency']
request_data['maxExecDuration'] = experiment_config['maxExecDuration']
request_data['maxTrialNum'] = experiment_config['maxTrialNum']
request_data['searchSpace'] = experiment_config['searchSpace']
request_data['tuner'] = experiment_config['tuner']
if 'assessor' in experiment_config:
request_data['assessor'] = experiment_config['assessor']
request_data['clusterMetaData'] = []
if experiment_config['trainingServicePlatform'] == 'local':
request_data['clusterMetaData'].append(
{'key':'codeDir', 'value':experiment_config['trial']['trialCodeDir']})
request_data['clusterMetaData'].append(
{'key': 'command', 'value': experiment_config['trial']['trialCommand']})
else:
request_data['clusterMetaData'].append(
{'key': 'machine_list', 'value': experiment_config['machineList']})
value_dict = dict()
value_dict['command'] = experiment_config['trial']['trialCommand']
value_dict['codeDir'] = experiment_config['trial']['trialCodeDir']
value_dict['gpuNum'] = experiment_config['trial']['trialGpuNum']
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': value_dict})
response = rest_post(experiment_url(port), json.dumps(request_data), 20)
return response if response.status_code == 200 else None
def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=None):
'''follow steps to start rest server and start experiment'''
nni_config = Config()
# start rest server
rest_process = start_rest_server(args.manager, REST_PORT, experiment_config['trainingServicePlatform'], mode, experiment_id)
nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation
if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), 'nni', 'annotation')
if os.path.isdir(path):
shutil.rmtree(path)
os.makedirs(path)
expand_annotations(experiment_config['trial']['trialCodeDir'], path)
experiment_config['trial']['trialCodeDir'] = path
search_space = generate_search_space(experiment_config['trial']['trialCodeDir'])
assert search_space, ERROR_INFO % 'Generated search space is empty'
else:
search_space = get_json_content(experiment_config['searchSpacePath'])
experiment_config['searchSpace'] = json.dumps(search_space)
# check rest server
print_normal('Checking restful server...')
if check_rest_server(REST_PORT):
print_normal('Restful server start success!')
else:
print_error('Restful server start failed!')
try:
rest_process.kill()
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(0)
# set remote config
if experiment_config['trainingServicePlatform'] == 'remote':
print_normal('Setting remote config...')
if set_remote_config(experiment_config, REST_PORT):
print_normal('Success!')
else:
print_error('Failed!')
try:
rest_process.kill()
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(0)
# set local config
if experiment_config['trainingServicePlatform'] == 'local':
print_normal('Setting local config...')
if set_local_config(experiment_config, REST_PORT):
print_normal('Success!')
else:
print_error('Failed!')
try:
rest_process.kill()
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(0)
# start a new experiment
print_normal('Starting experiment...')
response = set_experiment(experiment_config, mode, REST_PORT)
if response:
if experiment_id is None:
experiment_id = json.loads(response.text).get('experiment_id')
nni_config.set_config('experimentId', experiment_id)
else:
print_error('Failed!')
try:
rest_process.kill()
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(0)
#start webui
print_normal('Checking web ui...')
if check_web_ui():
print_error('{0} {1}'.format(' '.join(nni_config.get_config('webuiUrl')),'is being used, please stop it first!'))
print_normal('You can use \'nnictl webui stop\' to stop old web ui process...')
else:
print_normal('Starting web ui...')
webui_process = start_web_ui(webuiport)
nni_config.set_config('webuiPid', webui_process.pid)
print_normal('Starting web ui success!')
print_normal('{0} {1}'.format('Web UI url:', ' '.join(nni_config.get_config('webuiUrl'))))
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, REST_PORT))
def resume_experiment(args):
'''resume an experiment'''
nni_config = Config()
experiment_config = nni_config.get_config('experimentConfig')
experiment_id = nni_config.get_config('experimentId')
launch_experiment(args, experiment_config, 'resume', args.webuiport, experiment_id)
def create_experiment(args):
'''start a new experiment'''
nni_config = Config()
experiment_config = get_yml_content(args.config)
validate_all_content(experiment_config)
nni_config.set_config('experimentConfig', experiment_config)
launch_experiment(args, experiment_config, 'new', args.webuiport)
nni_config.set_config('restServerPort', REST_PORT)
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import os
def check_empty(experiment_config, key):
'''Check whether a key is in experiment_config and has non-empty value'''
if key not in experiment_config or experiment_config[key] is None:
raise ValueError('%s can not be empty' % key)
def check_digit(experiment_config, key, start, end):
'''Check whether a value in experiment_config is digit and in a range of [start, end]'''
if not str(experiment_config[key]).isdigit() or experiment_config[key] < start or \
experiment_config[key] > end:
raise ValueError('%s must be a digit from %s to %s' % (key, start, end))
def check_directory(experiment_config, key):
'''Check whether a value in experiment_config is a valid directory'''
if not os.path.isdir(experiment_config[key]):
raise NotADirectoryError('%s is not a valid directory' % key)
def check_file(experiment_config, key):
'''Check whether a value in experiment_config is a valid file'''
if not os.path.exists(experiment_config[key]):
raise FileNotFoundError('%s is not a valid file path' % key)
def check_choice(experiment_config, key, choice_list):
'''Check whether a value in experiment_config is in a choice list'''
if not experiment_config[key] in choice_list:
raise ValueError('%s must in [%s]' % (key, ','.join(choice_list)))
def parse_time(experiment_config, key):
'''Parse time format'''
unit = experiment_config[key][-1]
if unit not in ['s', 'm', 'h', 'd']:
raise ValueError('the unit of time could only from {s, m, h, d}')
time = experiment_config[key][:-1]
if not time.isdigit():
raise ValueError('time format error!')
parse_dict = {'s':1, 'm':60, 'h':3600, 'd':86400}
experiment_config[key] = int(time) * parse_dict[unit]
def validate_common_content(experiment_config):
'''Validate whether the common values in experiment_config is valid'''
#validate authorName
check_empty(experiment_config, 'authorName')
#validate experimentName
check_empty(experiment_config, 'experimentName')
#validate trialNoncurrency
check_empty(experiment_config, 'trialConcurrency')
check_digit(experiment_config, 'trialConcurrency', 1, 1000)
#validate execDuration
check_empty(experiment_config, 'maxExecDuration')
parse_time(experiment_config, 'maxExecDuration')
#validate maxTrialNum
check_empty(experiment_config, 'maxTrialNum')
check_digit(experiment_config, 'maxTrialNum', 1, 1000)
#validate trainingService
check_empty(experiment_config, 'trainingServicePlatform')
check_choice(experiment_config, 'trainingServicePlatform', ['local', 'remote'])
def validate_tuner_content(experiment_config):
'''Validate whether tuner in experiment_config is valid'''
tuner_algorithm_dict = {'TPE': 'nni.hyperopt_tuner --algorithm_name tpe',\
'Random': 'nni.hyperopt_tuner --algorithm_name random_search',\
'Anneal': 'nni.hyperopt_tuner --algorithm_name anneal',\
'Evolution': 'nni.evolution_tuner'}
check_empty(experiment_config, 'tuner')
#TODO: use elegent way to detect keys
if experiment_config['tuner'].get('tunerCommand') and experiment_config['tuner'].get('tunerCwd')\
and (experiment_config['tuner'].get('tunerName') or experiment_config['tuner'].get('optimizationMode'))\
or experiment_config['tuner'].get('tunerName') and experiment_config['tuner'].get('optimizationMode')\
and (experiment_config['tuner'].get('tunerCommand') or experiment_config['tuner'].get('tunerCwd')):
raise Exception('Please choose to use (tunerCommand, tunerCwd) or (tunerName, optimizationMode)')
if experiment_config['tuner'].get('tunerCommand') and experiment_config['tuner'].get('tunerCwd'):
check_directory(experiment_config['tuner'], 'tunerCwd')
experiment_config['tuner']['tunerCwd'] = os.path.abspath(experiment_config['tuner']['tunerCwd'])
elif experiment_config['tuner'].get('tunerName') and experiment_config['tuner'].get('optimizationMode'):
check_choice(experiment_config['tuner'], 'tunerName', ['TPE', 'Random', 'Anneal', 'Evolution'])
check_choice(experiment_config['tuner'], 'optimizationMode', ['Maximize', 'Minimize'])
if experiment_config['tuner']['optimizationMode'] == 'Maximize':
experiment_config['tuner']['optimizationMode'] = 'maximize'
else:
experiment_config['tuner']['optimizationMode'] = 'minimize'
experiment_config['tuner']['tunerCommand'] = 'python3 -m %s --optimize_mode %s'\
% (tuner_algorithm_dict.get(experiment_config['tuner']['tunerName']), experiment_config['tuner']['optimizationMode'])
experiment_config['tuner']['tunerCwd'] = ''
else:
raise ValueError('Please complete tuner information!')
if experiment_config['tuner'].get('tunerGpuNum'):
check_digit(experiment_config['tuner'], 'tunerGpuNum', 0, 100)
def validate_assessor_content(experiment_config):
'''Validate whether assessor in experiment_config is valid'''
assessor_algorithm_dict = {'Medianstop': 'nni.medianstop_assessor'}
if 'assessor' in experiment_config:
if experiment_config['assessor']:
if experiment_config['assessor'].get('assessorCommand') and experiment_config['assessor'].get('assessorCwd')\
and (experiment_config['assessor'].get('assessorName') or experiment_config['assessor'].get('optimizationMode'))\
or experiment_config['assessor'].get('assessorName') and experiment_config['assessor'].get('optimizationMode')\
and (experiment_config['assessor'].get('assessorCommand') or experiment_config['assessor'].get('assessorCwd')):
raise Exception('Please choose to use (assessorCommand, assessorCwd) or (assessorName, optimizationMode)')
if experiment_config['assessor'].get('assessorCommand') and experiment_config['assessor'].get('assessorCwd'):
check_empty(experiment_config['assessor'], 'assessorCommand')
check_empty(experiment_config['assessor'], 'assessorCwd')
check_directory(experiment_config['assessor'], 'assessorCwd')
experiment_config['assessor']['assessorCwd'] = os.path.abspath(experiment_config['assessor']['assessorCwd'])
if 'assessorGpuNum' in experiment_config['assessor']:
if experiment_config['assessor']['assessorGpuNum']:
check_digit(experiment_config['assessor'], 'assessorGpuNum', 0, 100)
elif experiment_config['assessor'].get('assessorName') and experiment_config['assessor'].get('optimizationMode'):
check_choice(experiment_config['assessor'], 'assessorName', ['Medianstop'])
check_choice(experiment_config['assessor'], 'optimizationMode', ['Maximize', 'Minimize'])
if experiment_config['assessor']['optimizationMode'] == 'Maximize':
experiment_config['assessor']['optimizationMode'] = 'maximize'
else:
experiment_config['assessor']['optimizationMode'] = 'minimize'
experiment_config['assessor']['assessorCommand'] = 'python3 -m %s --optimize_mode %s'\
% (assessor_algorithm_dict.get(experiment_config['assessor']['assessorName']), experiment_config['assessor']['optimizationMode'])
experiment_config['assessor']['assessorCwd'] = ''
else:
raise ValueError('Please complete assessor information!')
if experiment_config['assessor'].get('assessorGpuNum'):
check_digit(experiment_config['assessor'], 'assessorGpuNum', 0, 100)
def validate_trail_content(experiment_config):
'''Validate whether trial in experiment_config is valid'''
check_empty(experiment_config, 'trial')
check_empty(experiment_config['trial'], 'trialCommand')
check_empty(experiment_config['trial'], 'trialCodeDir')
check_directory(experiment_config['trial'], 'trialCodeDir')
experiment_config['trial']['trialCodeDir'] = os.path.abspath(experiment_config['trial']['trialCodeDir'])
check_empty(experiment_config['trial'], 'trialGpuNum')
check_digit(experiment_config['trial'], 'trialGpuNum', 0, 100)
def validate_machinelist_content(experiment_config):
'''Validate whether meachineList in experiment_config is valid'''
check_empty(experiment_config, 'machineList')
for i, machine in enumerate(experiment_config['machineList']):
check_empty(machine, 'ip')
if machine.get('port') is None:
experiment_config['machineList'][i]['port'] = 22
else:
check_digit(machine, 'port', 0, 65535)
check_empty(machine, 'username')
check_empty(machine, 'passwd')
def validate_annotation_content(experiment_config):
'''Valid whether useAnnotation and searchSpacePath is coexist'''
if experiment_config.get('useAnnotation'):
if experiment_config.get('searchSpacePath'):
print('searchSpacePath', experiment_config.get('searchSpacePath'))
raise Exception('If you set useAnnotation=true, please leave searchSpacePath empty')
else:
# validate searchSpaceFile
check_empty(experiment_config, 'searchSpacePath')
check_file(experiment_config, 'searchSpacePath')
def validate_all_content(experiment_config):
'''Validate whether experiment_config is valid'''
validate_common_content(experiment_config)
validate_tuner_content(experiment_config)
validate_assessor_content(experiment_config)
validate_trail_content(experiment_config)
# validate_annotation_content(experiment_config)
if experiment_config['trainingServicePlatform'] == 'remote':
validate_machinelist_content(experiment_config)
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import argparse
from .launcher import create_experiment, resume_experiment
from .updater import update_searchspace, update_concurrency, update_duration
from .nnictl_utils import *
def nni_help_info(*args):
print('please run "nnictl --help" to see nnictl guidance')
def parse_args():
'''Definite the arguments users need to follow and input'''
parser = argparse.ArgumentParser(prog='nni ctl', description='use nni control')
parser.set_defaults(func=nni_help_info)
# create subparsers for args with sub values
subparsers = parser.add_subparsers()
# parse start command
parser_start = subparsers.add_parser('create', help='create a new experiment')
parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
parser_start.add_argument('--manager', '-m', default='nnimanager', dest='manager')
parser_start.add_argument('--webuiport', '-w', default=8080, dest='webuiport')
parser_start.set_defaults(func=create_experiment)
# parse resume command
parser_resume = subparsers.add_parser('resume', help='resume a new experiment')
parser_resume.add_argument('--experiment', '-e', dest='id', help='ID of the experiment you want to resume')
parser_resume.add_argument('--manager', '-m', default='nnimanager', dest='manager')
parser_resume.add_argument('--webuiport', '-w', default=8080, dest='webuiport')
parser_resume.set_defaults(func=resume_experiment)
# parse update command
parser_updater = subparsers.add_parser('update', help='update the experiment')
#add subparsers for parser_updater
parser_updater_subparsers = parser_updater.add_subparsers()
parser_updater_searchspace = parser_updater_subparsers.add_parser('searchspace', help='update searchspace')
parser_updater_searchspace.add_argument('--filename', '-f', required=True)
parser_updater_searchspace.set_defaults(func=update_searchspace)
parser_updater_searchspace = parser_updater_subparsers.add_parser('concurrency', help='update concurrency')
parser_updater_searchspace.add_argument('--value', '-v', required=True)
parser_updater_searchspace.set_defaults(func=update_concurrency)
parser_updater_searchspace = parser_updater_subparsers.add_parser('duration', help='update duration')
parser_updater_searchspace.add_argument('--value', '-v', required=True)
parser_updater_searchspace.set_defaults(func=update_duration)
#parse stop command
parser_stop = subparsers.add_parser('stop', help='stop the experiment')
parser_stop.set_defaults(func=stop_experiment)
#parse trial command
parser_trial = subparsers.add_parser('trial', help='get trial information')
#add subparsers for parser_trial
parser_trial_subparsers = parser_trial.add_subparsers()
parser_trial_ls = parser_trial_subparsers.add_parser('ls', help='list trial jobs')
parser_trial_ls.set_defaults(func=trial_ls)
parser_trial_kill = parser_trial_subparsers.add_parser('kill', help='kill trial jobs')
parser_trial_kill.add_argument('--trialid', '-t', required=True, dest='trialid', help='the id of trial to be killed')
parser_trial_kill.set_defaults(func=trial_kill)
#TODO:finish webui function
#parse board command
parser_webui = subparsers.add_parser('webui', help='get web ui information')
#add subparsers for parser_board
parser_webui_subparsers = parser_webui.add_subparsers()
parser_webui_start = parser_webui_subparsers.add_parser('start', help='start web ui')
parser_webui_start.add_argument('--port', '-p', dest='port', default=8080, help='the port of web ui')
parser_webui_start.set_defaults(func=start_webui)
parser_webui_stop = parser_webui_subparsers.add_parser('stop', help='stop web ui')
parser_webui_stop.set_defaults(func=stop_webui)
parser_webui_url = parser_webui_subparsers.add_parser('url', help='show the url of web ui')
parser_webui_url.set_defaults(func=webui_url)
#parse experiment command
parser_experiment = subparsers.add_parser('experiment', help='get experiment information')
#add subparsers for parser_experiment
parser_experiment_subparsers = parser_experiment.add_subparsers()
parser_experiment_ls = parser_experiment_subparsers.add_parser('ls', help='list experiment')
parser_experiment_ls.set_defaults(func=list_experiment)
#parse config command
parser_config = subparsers.add_parser('config', help='get config information')
parser_config_subparsers = parser_config.add_subparsers()
parser_config_ls = parser_config_subparsers.add_parser('ls', help='list config')
parser_config_ls.set_defaults(func=get_config)
#parse restful server command
parser_rest = subparsers.add_parser('rest', help='get restful server information')
#add subparsers for parser_rest
parser_rest_subparsers = parser_rest.add_subparsers()
parser_rest_check = parser_rest_subparsers.add_parser('check', help='check restful server')
parser_rest_check.set_defaults(func=check_rest)
#parse log command
parser_log = subparsers.add_parser('log', help='get log information')
# add subparsers for parser_rest
parser_log_subparsers = parser_log.add_subparsers()
parser_log_stdout = parser_log_subparsers.add_parser('stdout', help='get stdout information')
parser_log_stdout.add_argument('--tail', '-T', dest='tail', type=int, help='get tail -100 content of stdout')
parser_log_stdout.add_argument('--head', '-H', dest='head', type=int, help='get head -100 content of stdout')
parser_log_stdout.add_argument('--path', '-p', action='store_true', default=False, help='get the path of stdout file')
parser_log_stdout.set_defaults(func=log_stdout)
parser_log_stderr = parser_log_subparsers.add_parser('stderr', help='get stderr information')
parser_log_stderr.add_argument('--tail', '-T', dest='tail', type=int, help='get tail -100 content of stderr')
parser_log_stderr.add_argument('--head', '-H', dest='head', type=int, help='get head -100 content of stderr')
parser_log_stderr.add_argument('--path', '-p', action='store_true', default=False, help='get the path of stderr file')
parser_log_stderr.set_defaults(func=log_stderr)
args = parser.parse_args()
args.func(args)
if __name__ == '__main__':
parse_args()
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import os
import psutil
import json
from subprocess import call, check_output
from .rest_utils import rest_get, rest_delete, check_rest_server_quick
from .config_utils import Config
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url
from .constants import STDERR_FULL_PATH, STDOUT_FULL_PATH
import time
from .common_utils import print_normal, print_error, detect_process
from .webui_utils import stop_web_ui, check_web_ui, start_web_ui
def check_rest(args):
'''check if restful server is running'''
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
if check_rest_server_quick(rest_port):
print_normal('Restful server is running...')
else:
print_normal('Restful server is not running...')
def stop_experiment(args):
'''Stop the experiment which is running'''
print_normal('Stoping experiment...')
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid):
print_normal('Experiment is not running...')
stop_web_ui()
return
if check_rest_server_quick(rest_port):
response = rest_delete(experiment_url(rest_port), 20)
if not response or response.status_code != 200:
print_error('Stop experiment failed!')
#sleep to wait rest handler done
time.sleep(3)
rest_pid = nni_config.get_config('restServerPid')
cmds = ['pkill', '-P', str(rest_pid)]
call(cmds)
stop_web_ui()
print_normal('Stop experiment success!')
def trial_ls(args):
'''List trial'''
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
if check_rest_server_quick(rest_port):
response = rest_get(trial_jobs_url(rest_port), 20)
if response and response.status_code == 200:
print(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
else:
print_error('List trial failed...')
else:
print_error('Restful server is not running...')
def trial_kill(args):
'''List trial'''
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
if check_rest_server_quick(rest_port):
response = rest_delete(trial_job_id_url(rest_port, args.trialid), 20)
if response and response.status_code == 200:
print(response.text)
else:
print_error('Kill trial job failed...')
else:
print_error('Restful server is not running...')
def list_experiment(args):
'''Get experiment information'''
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
if check_rest_server_quick(rest_port):
response = rest_get(experiment_url(rest_port), 20)
if response and response.status_code == 200:
print(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
else:
print_error('List experiment failed...')
else:
print_error('Restful server is not running...')
def get_log_content(file_name, cmds):
'''use cmds to read config content'''
if os.path.exists(file_name):
rest = check_output(cmds)
print(rest.decode('utf-8'))
else:
print_normal('NULL!')
def log_internal(args, filetype):
'''internal function to call get_log_content'''
if filetype == 'stdout':
file_full_path = STDOUT_FULL_PATH
else:
file_full_path = STDERR_FULL_PATH
if args.head:
get_log_content(file_full_path, ['head', '-' + str(args.head), file_full_path])
elif args.tail:
get_log_content(file_full_path, ['tail', '-' + str(args.tail), file_full_path])
elif args.path:
print_normal('The path of stdout file is: ' + file_full_path)
else:
get_log_content(file_full_path, ['cat', file_full_path])
def log_stdout(args):
'''get stdout log'''
log_internal(args, 'stdout')
def log_stderr(args):
'''get stderr log'''
log_internal(args, 'stderr')
def get_config(args):
'''get config info'''
nni_config = Config()
print(nni_config.get_all_config())
def start_webui(args):
'''start web ui'''
# start webui
print_normal('Checking webui...')
nni_config = Config()
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
if check_web_ui():
print_error('{0} {1}'.format(' '.join(nni_config.get_config('webuiUrl')), 'is being used, please stop it first!'))
print_normal('You can use \'nnictl webui stop\' to stop old web ui process...')
else:
print_normal('Starting webui...')
webui_process = start_web_ui(args.port)
nni_config = Config()
nni_config.set_config('webuiPid', webui_process.pid)
print_normal('Starting webui success!')
print_normal('{0} {1}'.format('Web UI url:', ' '.join(nni_config.get_config('webuiUrl'))))
def stop_webui(args):
'''stop web ui'''
print_normal('Stopping Web UI...')
if stop_web_ui():
print_normal('Web UI stopped success!')
else:
print_error('Web UI stop failed...')
def webui_url(args):
'''show the url of web ui'''
nni_config = Config()
print_normal('{0} {1}'.format('Web UI url:', ' '.join(nni_config.get_config('webuiUrl'))))
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import time
import requests
from .url_utils import check_status_url
def rest_put(url, data, timeout):
'''Call rest put method'''
try:
response = requests.put(url, headers={'Accept': 'application/json', 'Content-Type': 'application/json'},\
data=data, timeout=timeout)
return response
except Exception:
return None
def rest_post(url, data, timeout):
'''Call rest post method'''
try:
response = requests.post(url, headers={'Accept': 'application/json', 'Content-Type': 'application/json'},\
data=data, timeout=timeout)
return response
except Exception:
return None
def rest_get(url, timeout):
'''Call rest get method'''
try:
response = requests.get(url, timeout=timeout)
return response
except Exception:
return None
def rest_delete(url, timeout):
'''Call rest delete method'''
try:
response = requests.delete(url, timeout=timeout)
return response
except Exception:
return None
def check_rest_server(rest_port):
'''Check if restful server is ready'''
retry_count = 5
for _ in range(retry_count):
response = rest_get(check_status_url(rest_port), 20)
if response:
if response.status_code == 200:
return True
else:
return False
else:
time.sleep(3)
return False
def check_rest_server_quick(rest_port):
'''Check if restful server is ready, only check once'''
response = rest_get(check_status_url(rest_port), 5)
if response and response.status_code == 200:
return True
return False
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import json
import os
from .rest_utils import rest_put, rest_get, check_rest_server_quick
from .url_utils import experiment_url
from .config_utils import Config
from .common_utils import get_json_content
def validate_digit(value, start, end):
'''validate if a digit is valid'''
if not str(value).isdigit() or int(value) < start or int(value) > end:
raise ValueError('%s must be a digit from %s to %s' % (value, start, end))
def validate_file(path):
'''validate if a file exist'''
if not os.path.exists(path):
raise FileNotFoundError('%s is not a valid file path' % path)
def load_search_space(path):
'''load search space content'''
content = json.dumps(get_json_content(path))
if not content:
raise ValueError('searchSpace file should not be empty')
return content
def get_query_type(key):
'''get update query type'''
if key == 'trialConcurrency':
return '?update_type=TRIAL_CONCURRENCY'
if key == 'maxExecDuration':
return '?update_type=MAX_EXEC_DURATION'
if key == 'searchSpace':
return '?update_type=SEARCH_SPACE'
def update_experiment_profile(key, value):
'''call restful server to update experiment profile'''
nni_config = Config()
rest_port = nni_config.get_config('restServerPort')
if check_rest_server_quick(rest_port):
response = rest_get(experiment_url(rest_port), 20)
if response and response.status_code == 200:
experiment_profile = json.loads(response.text)
experiment_profile['params'][key] = value
response = rest_put(experiment_url(rest_port)+get_query_type(key), json.dumps(experiment_profile), 20)
if response and response.status_code == 200:
return response
else:
print('ERROR: restful server is not running...')
return None
def update_searchspace(args):
validate_file(args.filename)
content = load_search_space(args.filename)
if update_experiment_profile('searchSpace', content):
print('INFO: update %s success!' % 'searchSpace')
else:
print('ERROR: update %s failed!' % 'searchSpace')
def update_concurrency(args):
validate_digit(args.value, 1, 1000)
if update_experiment_profile('trialConcurrency', int(args.value)):
print('INFO: update %s success!' % 'concurrency')
else:
print('ERROR: update %s failed!' % 'concurrency')
def update_duration(args):
validate_digit(args.value, 1, 999999999)
if update_experiment_profile('maxExecDuration', int(args.value)):
print('INFO: update %s success!' % 'duration')
else:
print('ERROR: update %s failed!' % 'duration')
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
BASE_URL = 'http://localhost'
API_ROOT_URL = '/api/v1/nni'
EXPERIMENT_API = '/experiment'
CLUSTER_METADATA_API = '/experiment/cluster-metadata'
CHECK_STATUS_API = '/check-status'
TRIAL_JOBS_API = '/trial-jobs'
TENSORBOARD_API = '/tensorboard'
def check_status_url(port):
'''get check_status url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CHECK_STATUS_API)
def cluster_metadata_url(port):
'''get cluster_metadata_url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CLUSTER_METADATA_API)
def experiment_url(port):
'''get experiment_url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPERIMENT_API)
def trial_jobs_url(port):
'''get trial_jobs url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, TRIAL_JOBS_API)
def trial_job_id_url(port, job_id):
'''get trial_jobs with id url'''
return '{0}:{1}{2}{3}/:{4}'.format(BASE_URL, port, API_ROOT_URL, TRIAL_JOBS_API, job_id)
def tensorboard_url(port):
'''get tensorboard url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, TENSORBOARD_API)
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import psutil
from socket import AddressFamily
from .rest_utils import rest_get
from .config_utils import Config
from subprocess import Popen, PIPE
from .common_utils import print_error, print_normal
from .constants import STDOUT_FULL_PATH, STDERR_FULL_PATH
def start_web_ui(port):
'''start web ui'''
cmds = ['serve', '-s', '-n', '/usr/share/nni/webui', '-l', str(port)]
stdout_file = open(STDOUT_FULL_PATH, 'a+')
stderr_file = open(STDERR_FULL_PATH, 'a+')
webui_process = Popen(cmds, stdout=stdout_file, stderr=stderr_file)
if webui_process.returncode is None:
webui_url_list = []
for name, info in psutil.net_if_addrs().items():
for addr in info:
if AddressFamily.AF_INET == addr.family:
webui_url_list.append('http://{}:{}'.format(addr.address, port))
nni_config = Config()
nni_config.set_config('webuiUrl', webui_url_list)
else:
print_error('Failed to start webui')
return webui_process
def stop_web_ui():
'''stop web ui'''
nni_config = Config()
webuiPid = nni_config.get_config('webuiPid')
if not webuiPid:
return False
#detect webui process first
try:
parent_process = psutil.Process(webuiPid)
if not parent_process or not parent_process.is_running():
return False
except:
return False
#then kill webui process
try:
#in some environment, there will be multi processes, kill them all
parent_process = psutil.Process(webuiPid)
child_process_list = parent_process.children(recursive=True)
for child_process in child_process_list:
if child_process.is_running():
child_process.kill()
if parent_process.is_running():
parent_process.kill()
return True
except Exception as e:
print_error(e)
return False
def check_web_ui():
'''check if web ui is alive'''
nni_config = Config()
url_list = nni_config.get_config('webuiUrl')
if not url_list:
return False
for url in url_list:
response = rest_get(url, 20)
if response and response.status_code == 200:
return True
return False
\ No newline at end of file
#!/bin/bash
python3 -m nnicmd.nnictl $@
import setuptools
setuptools.setup(
name = 'nnictl',
version = '0.0.1',
packages = setuptools.find_packages(),
python_requires = '>=3.5',
install_requires = [
'requests',
'pyyaml',
'psutil'
],
author = 'Microsoft NNI Team',
author_email = 'nni@microsoft.com',
description = 'NNI control for Neural Network Intelligence project',
license = 'MIT',
url = 'https://msrasrg.visualstudio.com/NeuralNetworkIntelligence',
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment