Unverified Commit f68ba4a6 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

nnictl get port and pid from .experiment (#3235)

parent bcda469f
......@@ -509,6 +509,11 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
mode, experiment_id, foreground, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid)
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# Deal with annotation
if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
......@@ -546,11 +551,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
# start a new experiment
print_normal('Starting experiment...')
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# set debug configuration
if mode != 'view' and experiment_config.get('debug') is None:
experiment_config['debug'] = args.debug
......@@ -613,8 +613,7 @@ def create_experiment(args):
try:
launch_experiment(args, experiment_config, 'new', experiment_id)
except Exception as exception:
nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid')
restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
if restServerPid:
kill_command(restServerPid)
print_error(exception)
......@@ -646,8 +645,7 @@ def manage_stopped_experiment(args, mode):
try:
launch_experiment(args, experiment_config, mode, experiment_id)
except Exception as exception:
nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid')
restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
if restServerPid:
kill_command(restServerPid)
print_error(exception)
......
......@@ -50,11 +50,9 @@ def update_experiment():
for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED':
nni_config = Config(key)
rest_pid = nni_config.get_config('restServerPid')
rest_pid = experiment_dict[key].get('pid')
if not detect_process(rest_pid):
experiment_config.update_experiment(key, 'status', 'STOPPED')
experiment_config.update_experiment(key, 'port', None)
continue
def check_experiment_id(args, update=True):
......@@ -83,10 +81,10 @@ def check_experiment_id(args, update=True):
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'],
experiment_dict[key]['port'],
experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'),
experiment_dict[key]['startTime'],
experiment_dict[key]['endTime'])
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1)
elif not running_experiment_list:
......@@ -130,7 +128,7 @@ def parse_ids(args):
return running_experiment_list
if args.port is not None:
for key in running_experiment_list:
if experiment_dict[key]['port'] == args.port:
if experiment_dict[key].get('port') == args.port:
result_list.append(key)
if args.id and result_list and args.id != result_list[0]:
print_error('Experiment id and resful server port not match')
......@@ -143,10 +141,10 @@ def parse_ids(args):
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'],
experiment_dict[key]['port'],
experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'),
experiment_dict[key]['startTime'],
experiment_dict[key]['endTime'])
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1)
else:
......@@ -186,7 +184,7 @@ def get_experiment_port(args):
exit(1)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
return experiment_dict[experiment_id]['port']
return experiment_dict[experiment_id].get('port')
def convert_time_stamp_to_date(content):
'''Convert time stamp to date time format'''
......@@ -202,8 +200,9 @@ def convert_time_stamp_to_date(content):
def check_rest(args):
'''check if restful server is running'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port)
if running:
print_normal('Restful server is running...')
......@@ -220,18 +219,19 @@ def stop_experiment(args):
if experiment_id_list:
for experiment_id in experiment_id_list:
print_normal('Stopping experiment %s' % experiment_id)
nni_config = Config(experiment_id)
rest_pid = nni_config.get_config('restServerPid')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_pid = experiment_dict.get(experiment_id).get('pid')
if rest_pid:
kill_command(rest_pid)
tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
tensorboard_pid_list = experiment_dict.get(experiment_id).get('tensorboardPidList')
if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list:
try:
kill_command(tensorboard_pid)
except Exception as exception:
print_error(exception)
nni_config.set_config('tensorboardPidList', [])
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', [])
print_normal('Stop experiment success.')
def trial_ls(args):
......@@ -250,9 +250,10 @@ def trial_ls(args):
if args.head and args.tail:
print_error('Head and tail cannot be set at the same time.')
return
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -281,9 +282,10 @@ def trial_ls(args):
def trial_kill(args):
'''List trial'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -312,9 +314,10 @@ def trial_codegen(args):
def list_experiment(args):
'''Get experiment information'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -333,8 +336,9 @@ def list_experiment(args):
def experiment_status(args):
'''Show the status of experiment'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
result, response = check_rest_server_quick(rest_port)
if not result:
print_normal('Restful server is not running...')
......@@ -620,12 +624,12 @@ def platform_clean(args):
break
if platform == 'remote':
machine_list = config_content.get('machineList')
remote_clean(machine_list, None)
remote_clean(machine_list)
elif platform == 'pai':
host = config_content.get('paiConfig').get('host')
user_name = config_content.get('paiConfig').get('userName')
output_dir = config_content.get('trial').get('outputDir')
hdfs_clean(host, user_name, output_dir, None)
hdfs_clean(host, user_name, output_dir)
print_normal('Done.')
def experiment_list(args):
......@@ -651,7 +655,7 @@ def experiment_list(args):
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'],
experiment_dict[key]['port'],
experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
......@@ -752,9 +756,10 @@ def export_trials_data(args):
groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data']))
return groupby
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
......
......@@ -70,7 +70,7 @@ def format_tensorboard_log_path(path_list):
new_path_list.append('name%d:%s' % (index + 1, value))
return ','.join(new_path_list)
def start_tensorboard_process(args, nni_config, path_list, temp_nni_path):
def start_tensorboard_process(args, experiment_id, path_list, temp_nni_path):
'''call cmds to start tensorboard process in local machine'''
if detect_port(args.port):
print_error('Port %s is used by another process, please reset port!' % str(args.port))
......@@ -83,20 +83,19 @@ def start_tensorboard_process(args, nni_config, path_list, temp_nni_path):
url_list = get_local_urls(args.port)
print_green('Start tensorboard success!')
print_normal('Tensorboard urls: ' + ' '.join(url_list))
tensorboard_process_pid_list = nni_config.get_config('tensorboardPidList')
experiment_config = Experiments()
tensorboard_process_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
if tensorboard_process_pid_list is None:
tensorboard_process_pid_list = [tensorboard_process.pid]
else:
tensorboard_process_pid_list.append(tensorboard_process.pid)
nni_config.set_config('tensorboardPidList', tensorboard_process_pid_list)
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list)
def stop_tensorboard(args):
'''stop tensorboard'''
experiment_id = check_experiment_id(args)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
nni_config = Config(experiment_id)
tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
tensorboard_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list:
try:
......@@ -104,7 +103,7 @@ def stop_tensorboard(args):
call(cmds)
except Exception as exception:
print_error(exception)
nni_config.set_config('tensorboardPidList', [])
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', [])
print_normal('Stop tensorboard success!')
else:
print_error('No tensorboard configuration!')
......@@ -164,4 +163,4 @@ def start_tensorboard(args):
os.makedirs(temp_nni_path, exist_ok=True)
path_list = get_path_list(args, nni_config, trial_content, temp_nni_path)
start_tensorboard_process(args, nni_config, path_list, temp_nni_path)
\ No newline at end of file
start_tensorboard_process(args, experiment_id, path_list, temp_nni_path)
......@@ -5,7 +5,7 @@ import json
import os
from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url, import_data_url
from .config_utils import Config
from .config_utils import Config, Experiments
from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import get_experiment_port, get_config_filename, detect_process
from .launcher_utils import parse_time
......@@ -58,8 +58,9 @@ def get_query_type(key):
def update_experiment_profile(args, key, value):
'''call restful server to update experiment profile'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
......
......@@ -480,7 +480,6 @@ class NNIManager implements Manager {
}
await this.storeExperimentProfile();
this.setStatus('STOPPED');
this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'port', undefined);
}
private async periodicallyUpdateExecDuration(): Promise<void> {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment