Unverified commit cdee9c36, authored by SparkSnail, committed by GitHub
Browse files

Fix nnictl bugs and add new feature (#75)

* fix nnictl bug

* fix nnictl create bug

* add experiment status logic

* add more information for nnictl

* fix Evolution Tuner bug

* refactor code

* fix code in updater.py

* fix nnictl --help

* fix classArgs bug

* update check response.status_code logic
parent b58666ac
...@@ -41,8 +41,8 @@ Optional('searchSpacePath'): os.path.exists, ...@@ -41,8 +41,8 @@ Optional('searchSpacePath'): os.path.exists,
'codeDir': os.path.exists, 'codeDir': os.path.exists,
'classFileName': str, 'classFileName': str,
'className': str, 'className': str,
'classArgs': { Optional('classArgs'): {
'optimize_mode': Or('maximize', 'minimize'), Optional('optimize_mode'): Or('maximize', 'minimize'),
Optional('speed'): int Optional('speed'): int
}, },
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999), Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
......
...@@ -28,10 +28,10 @@ import tempfile ...@@ -28,10 +28,10 @@ import tempfile
from nni_annotation import * from nni_annotation import *
import random import random
from .launcher_utils import validate_all_content from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response
from .url_utils import cluster_metadata_url, experiment_url from .url_utils import cluster_metadata_url, experiment_url
from .config_utils import Config from .config_utils import Config
from .common_utils import get_yml_content, get_json_content, print_error, print_normal from .common_utils import get_yml_content, get_json_content, print_error, print_normal, detect_process
from .constants import EXPERIMENT_SUCCESS_INFO, STDOUT_FULL_PATH, STDERR_FULL_PATH, LOG_DIR, REST_PORT, ERROR_INFO, NORMAL_INFO from .constants import EXPERIMENT_SUCCESS_INFO, STDOUT_FULL_PATH, STDERR_FULL_PATH, LOG_DIR, REST_PORT, ERROR_INFO, NORMAL_INFO
from .webui_utils import start_web_ui, check_web_ui from .webui_utils import start_web_ui, check_web_ui
...@@ -40,7 +40,8 @@ def start_rest_server(port, platform, mode, experiment_id=None): ...@@ -40,7 +40,8 @@ def start_rest_server(port, platform, mode, experiment_id=None):
print_normal('Checking experiment...') print_normal('Checking experiment...')
nni_config = Config() nni_config = Config()
rest_port = nni_config.get_config('restServerPort') rest_port = nni_config.get_config('restServerPort')
if rest_port and check_rest_server_quick(rest_port): running, _ = check_rest_server_quick(rest_port)
if rest_port and running:
print_error('There is an experiment running, please stop it first...') print_error('There is an experiment running, please stop it first...')
print_normal('You can use \'nnictl stop\' command to stop an experiment!') print_normal('You can use \'nnictl stop\' command to stop an experiment!')
exit(0) exit(0)
...@@ -66,7 +67,12 @@ def set_trial_config(experiment_config, port): ...@@ -66,7 +67,12 @@ def set_trial_config(experiment_config, port):
value_dict['gpuNum'] = experiment_config['trial']['gpuNum'] value_dict['gpuNum'] = experiment_config['trial']['gpuNum']
request_data['trial_config'] = value_dict request_data['trial_config'] = value_dict
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20) response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20)
return True if response.status_code == 200 else False if check_response(response):
return True
else:
with open(STDERR_FULL_PATH, 'a+') as fout:
fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
return False
def set_local_config(experiment_config, port): def set_local_config(experiment_config, port):
'''set local configuration''' '''set local configuration'''
...@@ -79,9 +85,11 @@ def set_remote_config(experiment_config, port): ...@@ -79,9 +85,11 @@ def set_remote_config(experiment_config, port):
request_data['machine_list'] = experiment_config['machineList'] request_data['machine_list'] = experiment_config['machineList']
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20) response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20)
err_message = '' err_message = ''
if not response or not response.status_code == 200: if not response or not check_response(response):
if response is not None: if response is not None:
err_message = response.text err_message = response.text
with open(STDERR_FULL_PATH, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message return False, err_message
#set trial_config #set trial_config
...@@ -117,11 +125,22 @@ def set_experiment(experiment_config, mode, port): ...@@ -117,11 +125,22 @@ def set_experiment(experiment_config, mode, port):
{'key': 'trial_config', 'value': value_dict}) {'key': 'trial_config', 'value': value_dict})
response = rest_post(experiment_url(port), json.dumps(request_data), 20) response = rest_post(experiment_url(port), json.dumps(request_data), 20)
return response if response.status_code == 200 else None if check_response(response):
return response
else:
with open(STDERR_FULL_PATH, 'a+') as fout:
fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
return None
def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=None): def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=None):
'''follow steps to start rest server and start experiment''' '''follow steps to start rest server and start experiment'''
nni_config = Config() nni_config = Config()
#Check if there is an experiment running
origin_rest_pid = nni_config.get_config('restServerPid')
if origin_rest_pid and detect_process(origin_rest_pid):
print_error('There is an experiment running, please stop it first...')
print_normal('You can use \'nnictl stop\' command to stop an experiment!')
exit(0)
# start rest server # start rest server
rest_process = start_rest_server(REST_PORT, experiment_config['trainingServicePlatform'], mode, experiment_id) rest_process = start_rest_server(REST_PORT, experiment_config['trainingServicePlatform'], mode, experiment_id)
nni_config.set_config('restServerPid', rest_process.pid) nni_config.set_config('restServerPid', rest_process.pid)
...@@ -144,7 +163,8 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No ...@@ -144,7 +163,8 @@ def launch_experiment(args, experiment_config, mode, webuiport, experiment_id=No
# check rest server # check rest server
print_normal('Checking restful server...') print_normal('Checking restful server...')
if check_rest_server(REST_PORT): running, _ = check_rest_server(REST_PORT)
if running:
print_normal('Restful server start success!') print_normal('Restful server start success!')
else: else:
print_error('Restful server start failed!') print_error('Restful server start failed!')
......
...@@ -99,6 +99,7 @@ def parse_tuner_content(experiment_config): ...@@ -99,6 +99,7 @@ def parse_tuner_content(experiment_config):
if experiment_config['tuner'].get('builtinTunerName') and experiment_config['tuner'].get('classArgs'): if experiment_config['tuner'].get('builtinTunerName') and experiment_config['tuner'].get('classArgs'):
experiment_config['tuner']['className'] = tuner_class_name_dict.get(experiment_config['tuner']['builtinTunerName']) experiment_config['tuner']['className'] = tuner_class_name_dict.get(experiment_config['tuner']['builtinTunerName'])
if tuner_algorithm_name_dict.get(experiment_config['tuner']['builtinTunerName']):
experiment_config['tuner']['classArgs']['algorithm_name'] = tuner_algorithm_name_dict.get(experiment_config['tuner']['builtinTunerName']) experiment_config['tuner']['classArgs']['algorithm_name'] = tuner_algorithm_name_dict.get(experiment_config['tuner']['builtinTunerName'])
elif experiment_config['tuner'].get('codeDir') and experiment_config['tuner'].get('classFileName') and experiment_config['tuner'].get('className'): elif experiment_config['tuner'].get('codeDir') and experiment_config['tuner'].get('classFileName') and experiment_config['tuner'].get('className'):
if not os.path.exists(os.path.join(experiment_config['tuner']['codeDir'], experiment_config['tuner']['classFileName'])): if not os.path.exists(os.path.join(experiment_config['tuner']['codeDir'], experiment_config['tuner']['classFileName'])):
......
...@@ -25,11 +25,11 @@ from .updater import update_searchspace, update_concurrency, update_duration ...@@ -25,11 +25,11 @@ from .updater import update_searchspace, update_concurrency, update_duration
from .nnictl_utils import * from .nnictl_utils import *
def nni_help_info(*args):
    '''Default handler: point the user at the per-command help text.

    Any positional arguments are accepted and ignored, so this function
    can be registered through ``set_defaults(func=...)`` on the parser.
    '''
    guidance = 'please run "nnictl {positional argument} --help" to see nnictl guidance'
    print(guidance)
def parse_args(): def parse_args():
'''Definite the arguments users need to follow and input''' '''Definite the arguments users need to follow and input'''
parser = argparse.ArgumentParser(prog='nni ctl', description='use nni control') parser = argparse.ArgumentParser(prog='nnictl', description='use nnictl command to control nni experiments')
parser.set_defaults(func=nni_help_info) parser.set_defaults(func=nni_help_info)
# create subparsers for args with sub values # create subparsers for args with sub values
...@@ -95,6 +95,8 @@ def parse_args(): ...@@ -95,6 +95,8 @@ def parse_args():
parser_experiment_subparsers = parser_experiment.add_subparsers() parser_experiment_subparsers = parser_experiment.add_subparsers()
parser_experiment_show = parser_experiment_subparsers.add_parser('show', help='show the information of experiment') parser_experiment_show = parser_experiment_subparsers.add_parser('show', help='show the information of experiment')
parser_experiment_show.set_defaults(func=list_experiment) parser_experiment_show.set_defaults(func=list_experiment)
parser_experiment_status = parser_experiment_subparsers.add_parser('status', help='show the status of experiment')
parser_experiment_status.set_defaults(func=experiment_status)
#parse config command #parse config command
parser_config = subparsers.add_parser('config', help='get config information') parser_config = subparsers.add_parser('config', help='get config information')
......
...@@ -23,7 +23,7 @@ import psutil ...@@ -23,7 +23,7 @@ import psutil
import json import json
import datetime import datetime
from subprocess import call, check_output from subprocess import call, check_output
from .rest_utils import rest_get, rest_delete, check_rest_server_quick from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .config_utils import Config from .config_utils import Config
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url
from .constants import STDERR_FULL_PATH, STDOUT_FULL_PATH from .constants import STDERR_FULL_PATH, STDOUT_FULL_PATH
def check_rest(args):
    '''Check if the restful server is running and print the result.

    Reads 'restServerPort' from the nnictl config and probes the server
    once. `args` is not used by this function.
    '''
    nni_config = Config()
    rest_port = nni_config.get_config('restServerPort')
    # check_rest_server_quick returns (running, response); only the flag
    # is needed here.
    running, _ = check_rest_server_quick(rest_port)
    # The flag must select the matching message: a positive probe means
    # the server is up (the condition was previously negated).
    if running:
        print_normal('Restful server is running...')
    else:
        print_normal('Restful server is not running...')
...@@ -62,9 +63,10 @@ def stop_experiment(args): ...@@ -62,9 +63,10 @@ def stop_experiment(args):
print_normal('Experiment is not running...') print_normal('Experiment is not running...')
stop_web_ui() stop_web_ui()
return return
if check_rest_server_quick(rest_port): running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_delete(experiment_url(rest_port), 20) response = rest_delete(experiment_url(rest_port), 20)
if not response or response.status_code != 200: if not response or not check_response(response):
print_error('Stop experiment failed!') print_error('Stop experiment failed!')
#sleep to wait rest handler done #sleep to wait rest handler done
time.sleep(3) time.sleep(3)
...@@ -82,9 +84,10 @@ def trial_ls(args): ...@@ -82,9 +84,10 @@ def trial_ls(args):
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
if check_rest_server_quick(rest_port): running, response = check_rest_server_quick(rest_port)
if running:
response = rest_get(trial_jobs_url(rest_port), 20) response = rest_get(trial_jobs_url(rest_port), 20)
if response and response.status_code == 200: if response and check_response(response):
content = json.loads(response.text) content = json.loads(response.text)
for index, value in enumerate(content): for index, value in enumerate(content):
content[index] = convert_time_stamp_to_date(value) content[index] = convert_time_stamp_to_date(value)
...@@ -102,9 +105,10 @@ def trial_kill(args): ...@@ -102,9 +105,10 @@ def trial_kill(args):
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
if check_rest_server_quick(rest_port): running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_delete(trial_job_id_url(rest_port, args.trialid), 20) response = rest_delete(trial_job_id_url(rest_port, args.trialid), 20)
if response and response.status_code == 200: if response and check_response(response):
print(response.text) print(response.text)
else: else:
print_error('Kill trial job failed...') print_error('Kill trial job failed...')
...@@ -119,9 +123,10 @@ def list_experiment(args): ...@@ -119,9 +123,10 @@ def list_experiment(args):
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
if check_rest_server_quick(rest_port): running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_get(experiment_url(rest_port), 20) response = rest_get(experiment_url(rest_port), 20)
if response and response.status_code == 200: if response and check_response(response):
content = convert_time_stamp_to_date(json.loads(response.text)) content = convert_time_stamp_to_date(json.loads(response.text))
print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':'))) print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))
else: else:
...@@ -129,6 +134,16 @@ def list_experiment(args): ...@@ -129,6 +134,16 @@ def list_experiment(args):
else: else:
print_error('Restful server is not running...') print_error('Restful server is not running...')
def experiment_status(args):
    '''Print the status reported by the restful server.

    Looks up the configured 'restServerPort', probes the server once and
    pretty-prints the JSON body of the probe response. When the probe
    fails, a "not running" notice is printed instead. `args` is unused.
    '''
    rest_port = Config().get_config('restServerPort')
    running, status_response = check_rest_server_quick(rest_port)
    if running:
        status = json.loads(status_response.text)
        print(json.dumps(status, indent=4, sort_keys=True, separators=(',', ':')))
    else:
        print_normal('Restful server is not running...')
def get_log_content(file_name, cmds): def get_log_content(file_name, cmds):
'''use cmds to read config content''' '''use cmds to read config content'''
if os.path.exists(file_name): if os.path.exists(file_name):
......
...@@ -64,16 +64,22 @@ def check_rest_server(rest_port): ...@@ -64,16 +64,22 @@ def check_rest_server(rest_port):
response = rest_get(check_status_url(rest_port), 20) response = rest_get(check_status_url(rest_port), 20)
if response: if response:
if response.status_code == 200: if response.status_code == 200:
return True return True, response
else: else:
return False return False, response
else: else:
time.sleep(3) time.sleep(3)
return False return False, response
def check_rest_server_quick(rest_port):
    '''Probe the restful server a single time.

    Returns a ``(running, response)`` pair: ``(True, response)`` when the
    status endpoint answers with HTTP 200, otherwise ``(False, None)``.
    '''
    reply = rest_get(check_status_url(rest_port), 5)
    succeeded = bool(reply) and reply.status_code == 200
    return (True, reply) if succeeded else (False, None)
def check_response(response):
    '''Check if a response is success according to status_code.

    Returns True only when `response` is truthy and its `status_code`
    equals 200; returns False for a missing (None/falsy) response or any
    other status code. The result is always a plain bool.
    '''
    return bool(response) and response.status_code == 200
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
import json import json
import os import os
from .rest_utils import rest_put, rest_get, check_rest_server_quick from .rest_utils import rest_put, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url from .url_utils import experiment_url
from .config_utils import Config from .config_utils import Config
from .common_utils import get_json_content from .common_utils import get_json_content
...@@ -56,13 +56,14 @@ def update_experiment_profile(key, value): ...@@ -56,13 +56,14 @@ def update_experiment_profile(key, value):
'''call restful server to update experiment profile''' '''call restful server to update experiment profile'''
nni_config = Config() nni_config = Config()
rest_port = nni_config.get_config('restServerPort') rest_port = nni_config.get_config('restServerPort')
if check_rest_server_quick(rest_port): running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_get(experiment_url(rest_port), 20) response = rest_get(experiment_url(rest_port), 20)
if response and response.status_code == 200: if response and check_response(response):
experiment_profile = json.loads(response.text) experiment_profile = json.loads(response.text)
experiment_profile['params'][key] = value experiment_profile['params'][key] = value
response = rest_put(experiment_url(rest_port)+get_query_type(key), json.dumps(experiment_profile), 20) response = rest_put(experiment_url(rest_port)+get_query_type(key), json.dumps(experiment_profile), 20)
if response and response.status_code == 200: if response and check_response(response):
return response return response
else: else:
print('ERROR: restful server is not running...') print('ERROR: restful server is not running...')
......
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
import os import os
import psutil import psutil
from socket import AddressFamily from socket import AddressFamily
from subprocess import Popen, PIPE from subprocess import Popen, PIPE, call
from .rest_utils import rest_get from .rest_utils import rest_get, check_response
from .config_utils import Config from .config_utils import Config
from .common_utils import print_error, print_normal from .common_utils import print_error, print_normal
from .constants import STDOUT_FULL_PATH, STDERR_FULL_PATH from .constants import STDOUT_FULL_PATH, STDERR_FULL_PATH
...@@ -71,6 +71,8 @@ def stop_web_ui(): ...@@ -71,6 +71,8 @@ def stop_web_ui():
child_process.kill() child_process.kill()
if parent_process.is_running(): if parent_process.is_running():
parent_process.kill() parent_process.kill()
cmds = ['pkill', '-P', str(webuiPid)]
call(cmds)
return True return True
except Exception as e: except Exception as e:
print_error(e) print_error(e)
...@@ -84,6 +86,6 @@ def check_web_ui(): ...@@ -84,6 +86,6 @@ def check_web_ui():
return False return False
for url in url_list: for url in url_list:
response = rest_get(url, 3) response = rest_get(url, 3)
if response and response.status_code == 200: if response and check_response(response):
return True return True
return False return False
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment