Unverified Commit f075aab0 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Fix nnictl experiment list command (#934)

parent 21b48d29
......@@ -79,7 +79,7 @@ class Experiments:
self.experiments[id]['port'] = port
self.experiments[id]['startTime'] = time
self.experiments[id]['endTime'] = 'N/A'
self.experiments[id]['status'] = 'running'
self.experiments[id]['status'] = 'INITIALIZED'
self.experiments[id]['fileName'] = file_name
self.experiments[id]['platform'] = platform
self.write_file()
......
......@@ -30,6 +30,8 @@ WARNING_INFO = 'WARNING: %s'
DEFAULT_REST_PORT = 8080
REST_TIME_OUT = 20
EXPERIMENT_SUCCESS_INFO = '\033[1;32;32mSuccessfully started experiment!\n\033[0m' \
'-----------------------------------------------------------------------\n' \
'The experiment id is %s\n'\
......
......@@ -139,7 +139,7 @@ def set_trial_config(experiment_config, port, config_file_name):
'''set trial configuration'''
request_data = dict()
request_data['trial_config'] = experiment_config['trial']
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20)
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
if check_response(response):
return True
else:
......@@ -159,7 +159,7 @@ def set_remote_config(experiment_config, port, config_file_name):
#set machine_list
request_data = dict()
request_data['machine_list'] = experiment_config['machineList']
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 20)
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
err_message = ''
if not response or not check_response(response):
if response is not None:
......@@ -180,7 +180,7 @@ def setNNIManagerIp(experiment_config, port, config_file_name):
return True, None
ip_config_dict = dict()
ip_config_dict['nni_manager_ip'] = { 'nniManagerIp' : experiment_config['nniManagerIp'] }
response = rest_put(cluster_metadata_url(port), json.dumps(ip_config_dict), 20)
response = rest_put(cluster_metadata_url(port), json.dumps(ip_config_dict), REST_TIME_OUT)
err_message = None
if not response or not response.status_code == 200:
if response is not None:
......@@ -195,7 +195,7 @@ def set_pai_config(experiment_config, port, config_file_name):
'''set pai configuration'''
pai_config_data = dict()
pai_config_data['pai_config'] = experiment_config['paiConfig']
response = rest_put(cluster_metadata_url(port), json.dumps(pai_config_data), 20)
response = rest_put(cluster_metadata_url(port), json.dumps(pai_config_data), REST_TIME_OUT)
err_message = None
if not response or not response.status_code == 200:
if response is not None:
......@@ -214,7 +214,7 @@ def set_kubeflow_config(experiment_config, port, config_file_name):
'''set kubeflow configuration'''
kubeflow_config_data = dict()
kubeflow_config_data['kubeflow_config'] = experiment_config['kubeflowConfig']
response = rest_put(cluster_metadata_url(port), json.dumps(kubeflow_config_data), 20)
response = rest_put(cluster_metadata_url(port), json.dumps(kubeflow_config_data), REST_TIME_OUT)
err_message = None
if not response or not response.status_code == 200:
if response is not None:
......@@ -233,7 +233,7 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name):
'''set kubeflow configuration'''
frameworkcontroller_config_data = dict()
frameworkcontroller_config_data['frameworkcontroller_config'] = experiment_config['frameworkcontrollerConfig']
response = rest_put(cluster_metadata_url(port), json.dumps(frameworkcontroller_config_data), 20)
response = rest_put(cluster_metadata_url(port), json.dumps(frameworkcontroller_config_data), REST_TIME_OUT)
err_message = None
if not response or not response.status_code == 200:
if response is not None:
......@@ -304,7 +304,7 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
response = rest_post(experiment_url(port), json.dumps(request_data), 20)
response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT)
if check_response(response):
return response
else:
......@@ -488,7 +488,7 @@ def resume_experiment(args):
if experiment_dict.get(args.id) is None:
print_error('Id %s not exist!' % args.id)
exit(1)
if experiment_dict[args.id]['status'] == 'running':
if experiment_dict[args.id]['status'] != 'STOPPED':
print_error('Experiment %s is running!' % args.id)
exit(1)
experiment_id = args.id
......
......@@ -28,10 +28,25 @@ from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_re
from .config_utils import Config, Experiments
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url
from .constants import NNICTL_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
from .common_utils import print_normal, print_error, print_warning, detect_process
def update_experiment_status():
def get_experiment_time(port):
    '''Query the experiment REST endpoint on `port` for its start/end time.

    Returns a `(startTime, endTime)` tuple read from the experiment content
    after it has been passed through `convert_time_stamp_to_date`. Either
    element is None when the key is absent. Returns `(None, None)` when the
    request yields a falsy response or `check_response` rejects it.
    '''
    reply = rest_get(experiment_url(port), REST_TIME_OUT)
    # Bail out early when there is no usable reply from the REST server.
    if not reply or not check_response(reply):
        return None, None
    profile = convert_time_stamp_to_date(json.loads(reply.text))
    start_time = profile.get('startTime')
    end_time = profile.get('endTime')
    return start_time, end_time
def get_experiment_status(port):
    '''Return the `status` field reported by the REST server on `port`.

    The server is probed with `check_rest_server_quick`; when that probe
    reports failure, None is returned. Otherwise the response body is parsed
    as JSON and its 'status' entry is returned (None if the key is missing).
    '''
    is_alive, reply = check_rest_server_quick(port)
    if not is_alive:
        return None
    payload = json.loads(reply.text)
    return payload.get('status')
def update_experiment():
'''Update the experiment status in config file'''
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
......@@ -39,16 +54,26 @@ def update_experiment_status():
return None
for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') == 'running':
if experiment_dict[key].get('status') != 'STOPPED':
nni_config = Config(experiment_dict[key]['fileName'])
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid):
experiment_config.update_experiment(key, 'status', 'stopped')
experiment_config.update_experiment(key, 'status', 'STOPPED')
continue
rest_port = nni_config.get_config('restServerPort')
startTime, endTime = get_experiment_time(rest_port)
if startTime:
experiment_config.update_experiment(key, 'startTime', startTime)
if endTime:
experiment_config.update_experiment(key, 'endTime', endTime)
status = get_experiment_status(rest_port)
if status:
experiment_config.update_experiment(key, 'status', status)
def check_experiment_id(args):
'''check if the id is valid
'''
update_experiment_status()
update_experiment()
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict:
......@@ -58,13 +83,13 @@ def check_experiment_id(args):
running_experiment_list = []
for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') == 'running':
if experiment_dict[key].get('status') != 'STOPPED':
running_experiment_list.append(key)
elif isinstance(experiment_dict[key], list):
# if the config file is old version, remove the configuration from file
experiment_config.remove_experiment(key)
if len(running_experiment_list) > 1:
print_error('There are multiple experiments running, please set the experiment id...')
print_error('There are multiple experiments, please set the experiment id...')
experiment_information = ""
for key in running_experiment_list:
experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], \
......@@ -94,7 +119,7 @@ def parse_ids(args):
5.If the id does not exist but match the prefix of an experiment id, nnictl will return the matched id
6.If the id does not exist but match multiple prefix of the experiment ids, nnictl will give id information
'''
update_experiment_status()
update_experiment()
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict:
......@@ -104,14 +129,14 @@ def parse_ids(args):
running_experiment_list = []
for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') == 'running':
if experiment_dict[key].get('status') != 'STOPPED':
running_experiment_list.append(key)
elif isinstance(experiment_dict[key], list):
# if the config file is old version, remove the configuration from file
experiment_config.remove_experiment(key)
if not args.id:
if len(running_experiment_list) > 1:
print_error('There are multiple experiments running, please set the experiment id...')
print_error('There are multiple experiments, please set the experiment id...')
experiment_information = ""
for key in running_experiment_list:
experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], \
......@@ -207,7 +232,7 @@ def stop_experiment(args):
print_error(exception)
nni_config.set_config('tensorboardPidList', [])
print_normal('Stop experiment success!')
experiment_config.update_experiment(experiment_id, 'status', 'stopped')
experiment_config.update_experiment(experiment_id, 'status', 'STOPPED')
time_now = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
experiment_config.update_experiment(experiment_id, 'endTime', str(time_now))
......@@ -221,7 +246,7 @@ def trial_ls(args):
return
running, response = check_rest_server_quick(rest_port)
if running:
response = rest_get(trial_jobs_url(rest_port), 20)
response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
if response and check_response(response):
content = json.loads(response.text)
for index, value in enumerate(content):
......@@ -242,7 +267,7 @@ def trial_kill(args):
return
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_delete(trial_job_id_url(rest_port, args.id), 20)
response = rest_delete(trial_job_id_url(rest_port, args.id), REST_TIME_OUT)
if response and check_response(response):
print(response.text)
else:
......@@ -260,7 +285,7 @@ def list_experiment(args):
return
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_get(experiment_url(rest_port), 20)
response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
if response and check_response(response):
content = convert_time_stamp_to_date(json.loads(response.text))
print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))
......@@ -322,7 +347,7 @@ def log_trial(args):
return
running, response = check_rest_server_quick(rest_port)
if running:
response = rest_get(trial_jobs_url(rest_port), 20)
response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
if response and check_response(response):
content = json.loads(response.text)
for trial in content:
......@@ -362,18 +387,20 @@ def experiment_list(args):
if not experiment_dict:
print('There is no experiment running...')
exit(1)
update_experiment()
experiment_id_list = []
if args.all and args.all == 'all':
for key in experiment_dict.keys():
experiment_id_list.append(key)
else:
for key in experiment_dict.keys():
if experiment_dict[key]['status'] == 'running':
if experiment_dict[key]['status'] != 'STOPPED':
experiment_id_list.append(key)
if not experiment_id_list:
print_warning('There is no experiment running...\nYou can use \'nnictl experiment list all\' to list all stopped experiments!')
experiment_information = ""
for key in experiment_id_list:
experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], experiment_dict[key]['port'],\
experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
......@@ -382,8 +409,8 @@ def get_time_interval(time1, time2):
'''get the interval of two times'''
try:
#convert time to timestamp
time1 = time.mktime(time.strptime(time1, '%Y-%m-%d %H:%M:%S'))
time2 = time.mktime(time.strptime(time2, '%Y-%m-%d %H:%M:%S'))
time1 = time.mktime(time.strptime(time1, '%Y/%m/%d %H:%M:%S'))
time2 = time.mktime(time.strptime(time2, '%Y/%m/%d %H:%M:%S'))
seconds = (datetime.datetime.fromtimestamp(time2) - datetime.datetime.fromtimestamp(time1)).seconds
#convert seconds to day:hour:minute:second
days = seconds / 86400
......@@ -403,21 +430,21 @@ def show_experiment_info():
if not experiment_dict:
print('There is no experiment running...')
exit(1)
update_experiment()
experiment_id_list = []
for key in experiment_dict.keys():
if experiment_dict[key]['status'] == 'running':
if experiment_dict[key]['status'] != 'STOPPED':
experiment_id_list.append(key)
if not experiment_id_list:
print_warning('There is no experiment running...')
return
for key in experiment_id_list:
current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
print(EXPERIMENT_MONITOR_INFO % (key, experiment_dict[key]['status'], experiment_dict[key]['port'], \
experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], get_time_interval(experiment_dict[key]['startTime'], current_time)))
experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime'])))
print(TRIAL_MONITOR_HEAD)
running, response = check_rest_server_quick(experiment_dict[key]['port'])
if running:
response = rest_get(trial_jobs_url(experiment_dict[key]['port']), 20)
response = rest_get(trial_jobs_url(experiment_dict[key]['port']), REST_TIME_OUT)
if response and check_response(response):
content = json.loads(response.text)
for index, value in enumerate(content):
......@@ -433,7 +460,7 @@ def monitor_experiment(args):
while True:
try:
os.system('clear')
update_experiment_status()
update_experiment()
show_experiment_info()
time.sleep(args.time)
except KeyboardInterrupt:
......
......@@ -22,6 +22,7 @@
import time
import requests
from .url_utils import check_status_url
from .constants import REST_TIME_OUT
def rest_put(url, data, timeout):
'''Call rest put method'''
......@@ -61,7 +62,7 @@ def check_rest_server(rest_port):
'''Check if restful server is ready'''
retry_count = 5
for _ in range(retry_count):
response = rest_get(check_status_url(rest_port), 20)
response = rest_get(check_status_url(rest_port), REST_TIME_OUT)
if response:
if response.status_code == 200:
return True, response
......
......@@ -144,7 +144,7 @@ def start_tensorboard(args):
running, response = check_rest_server_quick(rest_port)
trial_content = None
if running:
response = rest_get(trial_jobs_url(rest_port), 20)
response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
if response and check_response(response):
trial_content = json.loads(response.text)
else:
......
......@@ -27,6 +27,7 @@ from .config_utils import Config
from .common_utils import get_json_content
from .nnictl_utils import check_experiment_id, get_experiment_port, get_config_filename
from .launcher_utils import parse_time
from .constants import REST_TIME_OUT
def validate_digit(value, start, end):
'''validate if a digit is valid'''
......@@ -62,11 +63,11 @@ def update_experiment_profile(args, key, value):
rest_port = nni_config.get_config('restServerPort')
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_get(experiment_url(rest_port), 20)
response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
if response and check_response(response):
experiment_profile = json.loads(response.text)
experiment_profile['params'][key] = value
response = rest_put(experiment_url(rest_port)+get_query_type(key), json.dumps(experiment_profile), 20)
response = rest_put(experiment_url(rest_port)+get_query_type(key), json.dumps(experiment_profile), REST_TIME_OUT)
if response and check_response(response):
return response
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment