Unverified Commit f68ba4a6 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

nnictl get port and pid from .experiment (#3235)

parent bcda469f
...@@ -509,6 +509,11 @@ def launch_experiment(args, experiment_config, mode, experiment_id): ...@@ -509,6 +509,11 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \ rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
mode, experiment_id, foreground, log_dir, log_level) mode, experiment_id, foreground, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid) nni_config.set_config('restServerPid', rest_process.pid)
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# Deal with annotation # Deal with annotation
if experiment_config.get('useAnnotation'): if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation') path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
...@@ -546,11 +551,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id): ...@@ -546,11 +551,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
# start a new experiment # start a new experiment
print_normal('Starting experiment...') print_normal('Starting experiment...')
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# set debug configuration # set debug configuration
if mode != 'view' and experiment_config.get('debug') is None: if mode != 'view' and experiment_config.get('debug') is None:
experiment_config['debug'] = args.debug experiment_config['debug'] = args.debug
...@@ -613,8 +613,7 @@ def create_experiment(args): ...@@ -613,8 +613,7 @@ def create_experiment(args):
try: try:
launch_experiment(args, experiment_config, 'new', experiment_id) launch_experiment(args, experiment_config, 'new', experiment_id)
except Exception as exception: except Exception as exception:
nni_config = Config(experiment_id) restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
restServerPid = nni_config.get_config('restServerPid')
if restServerPid: if restServerPid:
kill_command(restServerPid) kill_command(restServerPid)
print_error(exception) print_error(exception)
...@@ -646,8 +645,7 @@ def manage_stopped_experiment(args, mode): ...@@ -646,8 +645,7 @@ def manage_stopped_experiment(args, mode):
try: try:
launch_experiment(args, experiment_config, mode, experiment_id) launch_experiment(args, experiment_config, mode, experiment_id)
except Exception as exception: except Exception as exception:
nni_config = Config(experiment_id) restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
restServerPid = nni_config.get_config('restServerPid')
if restServerPid: if restServerPid:
kill_command(restServerPid) kill_command(restServerPid)
print_error(exception) print_error(exception)
......
...@@ -50,11 +50,9 @@ def update_experiment(): ...@@ -50,11 +50,9 @@ def update_experiment():
for key in experiment_dict.keys(): for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict): if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED': if experiment_dict[key].get('status') != 'STOPPED':
nni_config = Config(key) rest_pid = experiment_dict[key].get('pid')
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
experiment_config.update_experiment(key, 'status', 'STOPPED') experiment_config.update_experiment(key, 'status', 'STOPPED')
experiment_config.update_experiment(key, 'port', None)
continue continue
def check_experiment_id(args, update=True): def check_experiment_id(args, update=True):
...@@ -83,10 +81,10 @@ def check_experiment_id(args, update=True): ...@@ -83,10 +81,10 @@ def check_experiment_id(args, update=True):
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key, experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'), experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'], experiment_dict[key]['status'],
experiment_dict[key]['port'], experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'), experiment_dict[key].get('platform'),
experiment_dict[key]['startTime'], time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
experiment_dict[key]['endTime']) time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1) exit(1)
elif not running_experiment_list: elif not running_experiment_list:
...@@ -130,7 +128,7 @@ def parse_ids(args): ...@@ -130,7 +128,7 @@ def parse_ids(args):
return running_experiment_list return running_experiment_list
if args.port is not None: if args.port is not None:
for key in running_experiment_list: for key in running_experiment_list:
if experiment_dict[key]['port'] == args.port: if experiment_dict[key].get('port') == args.port:
result_list.append(key) result_list.append(key)
if args.id and result_list and args.id != result_list[0]: if args.id and result_list and args.id != result_list[0]:
print_error('Experiment id and resful server port not match') print_error('Experiment id and resful server port not match')
...@@ -143,10 +141,10 @@ def parse_ids(args): ...@@ -143,10 +141,10 @@ def parse_ids(args):
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key, experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'), experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'], experiment_dict[key]['status'],
experiment_dict[key]['port'], experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'), experiment_dict[key].get('platform'),
experiment_dict[key]['startTime'], time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
experiment_dict[key]['endTime']) time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1) exit(1)
else: else:
...@@ -186,7 +184,7 @@ def get_experiment_port(args): ...@@ -186,7 +184,7 @@ def get_experiment_port(args):
exit(1) exit(1)
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiment_dict = experiment_config.get_all_experiments()
return experiment_dict[experiment_id]['port'] return experiment_dict[experiment_id].get('port')
def convert_time_stamp_to_date(content): def convert_time_stamp_to_date(content):
'''Convert time stamp to date time format''' '''Convert time stamp to date time format'''
...@@ -202,8 +200,9 @@ def convert_time_stamp_to_date(content): ...@@ -202,8 +200,9 @@ def convert_time_stamp_to_date(content):
def check_rest(args): def check_rest(args):
'''check if restful server is running''' '''check if restful server is running'''
nni_config = Config(get_config_filename(args)) experiment_config = Experiments()
rest_port = nni_config.get_config('restServerPort') experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port) running, _ = check_rest_server_quick(rest_port)
if running: if running:
print_normal('Restful server is running...') print_normal('Restful server is running...')
...@@ -220,18 +219,19 @@ def stop_experiment(args): ...@@ -220,18 +219,19 @@ def stop_experiment(args):
if experiment_id_list: if experiment_id_list:
for experiment_id in experiment_id_list: for experiment_id in experiment_id_list:
print_normal('Stopping experiment %s' % experiment_id) print_normal('Stopping experiment %s' % experiment_id)
nni_config = Config(experiment_id) experiment_config = Experiments()
rest_pid = nni_config.get_config('restServerPid') experiment_dict = experiment_config.get_all_experiments()
rest_pid = experiment_dict.get(experiment_id).get('pid')
if rest_pid: if rest_pid:
kill_command(rest_pid) kill_command(rest_pid)
tensorboard_pid_list = nni_config.get_config('tensorboardPidList') tensorboard_pid_list = experiment_dict.get(experiment_id).get('tensorboardPidList')
if tensorboard_pid_list: if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list: for tensorboard_pid in tensorboard_pid_list:
try: try:
kill_command(tensorboard_pid) kill_command(tensorboard_pid)
except Exception as exception: except Exception as exception:
print_error(exception) print_error(exception)
nni_config.set_config('tensorboardPidList', []) experiment_config.update_experiment(experiment_id, 'tensorboardPidList', [])
print_normal('Stop experiment success.') print_normal('Stop experiment success.')
def trial_ls(args): def trial_ls(args):
...@@ -250,9 +250,10 @@ def trial_ls(args): ...@@ -250,9 +250,10 @@ def trial_ls(args):
if args.head and args.tail: if args.head and args.tail:
print_error('Head and tail cannot be set at the same time.') print_error('Head and tail cannot be set at the same time.')
return return
nni_config = Config(get_config_filename(args)) experiment_config = Experiments()
rest_port = nni_config.get_config('restServerPort') experiment_dict = experiment_config.get_all_experiments()
rest_pid = nni_config.get_config('restServerPid') rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
...@@ -281,9 +282,10 @@ def trial_ls(args): ...@@ -281,9 +282,10 @@ def trial_ls(args):
def trial_kill(args): def trial_kill(args):
'''List trial''' '''List trial'''
nni_config = Config(get_config_filename(args)) experiment_config = Experiments()
rest_port = nni_config.get_config('restServerPort') experiment_dict = experiment_config.get_all_experiments()
rest_pid = nni_config.get_config('restServerPid') rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
...@@ -312,9 +314,10 @@ def trial_codegen(args): ...@@ -312,9 +314,10 @@ def trial_codegen(args):
def list_experiment(args): def list_experiment(args):
'''Get experiment information''' '''Get experiment information'''
nni_config = Config(get_config_filename(args)) experiment_config = Experiments()
rest_port = nni_config.get_config('restServerPort') experiment_dict = experiment_config.get_all_experiments()
rest_pid = nni_config.get_config('restServerPid') rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
...@@ -333,8 +336,9 @@ def list_experiment(args): ...@@ -333,8 +336,9 @@ def list_experiment(args):
def experiment_status(args): def experiment_status(args):
'''Show the status of experiment''' '''Show the status of experiment'''
nni_config = Config(get_config_filename(args)) experiment_config = Experiments()
rest_port = nni_config.get_config('restServerPort') experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
result, response = check_rest_server_quick(rest_port) result, response = check_rest_server_quick(rest_port)
if not result: if not result:
print_normal('Restful server is not running...') print_normal('Restful server is not running...')
...@@ -620,12 +624,12 @@ def platform_clean(args): ...@@ -620,12 +624,12 @@ def platform_clean(args):
break break
if platform == 'remote': if platform == 'remote':
machine_list = config_content.get('machineList') machine_list = config_content.get('machineList')
remote_clean(machine_list, None) remote_clean(machine_list)
elif platform == 'pai': elif platform == 'pai':
host = config_content.get('paiConfig').get('host') host = config_content.get('paiConfig').get('host')
user_name = config_content.get('paiConfig').get('userName') user_name = config_content.get('paiConfig').get('userName')
output_dir = config_content.get('trial').get('outputDir') output_dir = config_content.get('trial').get('outputDir')
hdfs_clean(host, user_name, output_dir, None) hdfs_clean(host, user_name, output_dir)
print_normal('Done.') print_normal('Done.')
def experiment_list(args): def experiment_list(args):
...@@ -651,7 +655,7 @@ def experiment_list(args): ...@@ -651,7 +655,7 @@ def experiment_list(args):
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key, experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'), experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'], experiment_dict[key]['status'],
experiment_dict[key]['port'], experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'), experiment_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'], time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime']) time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
...@@ -752,9 +756,10 @@ def export_trials_data(args): ...@@ -752,9 +756,10 @@ def export_trials_data(args):
groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data'])) groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data']))
return groupby return groupby
nni_config = Config(get_config_filename(args)) experiment_config = Experiments()
rest_port = nni_config.get_config('restServerPort') experiment_dict = experiment_config.get_all_experiments()
rest_pid = nni_config.get_config('restServerPid') rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
......
...@@ -70,7 +70,7 @@ def format_tensorboard_log_path(path_list): ...@@ -70,7 +70,7 @@ def format_tensorboard_log_path(path_list):
new_path_list.append('name%d:%s' % (index + 1, value)) new_path_list.append('name%d:%s' % (index + 1, value))
return ','.join(new_path_list) return ','.join(new_path_list)
def start_tensorboard_process(args, nni_config, path_list, temp_nni_path): def start_tensorboard_process(args, experiment_id, path_list, temp_nni_path):
'''call cmds to start tensorboard process in local machine''' '''call cmds to start tensorboard process in local machine'''
if detect_port(args.port): if detect_port(args.port):
print_error('Port %s is used by another process, please reset port!' % str(args.port)) print_error('Port %s is used by another process, please reset port!' % str(args.port))
...@@ -83,20 +83,19 @@ def start_tensorboard_process(args, nni_config, path_list, temp_nni_path): ...@@ -83,20 +83,19 @@ def start_tensorboard_process(args, nni_config, path_list, temp_nni_path):
url_list = get_local_urls(args.port) url_list = get_local_urls(args.port)
print_green('Start tensorboard success!') print_green('Start tensorboard success!')
print_normal('Tensorboard urls: ' + ' '.join(url_list)) print_normal('Tensorboard urls: ' + ' '.join(url_list))
tensorboard_process_pid_list = nni_config.get_config('tensorboardPidList') experiment_config = Experiments()
tensorboard_process_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
if tensorboard_process_pid_list is None: if tensorboard_process_pid_list is None:
tensorboard_process_pid_list = [tensorboard_process.pid] tensorboard_process_pid_list = [tensorboard_process.pid]
else: else:
tensorboard_process_pid_list.append(tensorboard_process.pid) tensorboard_process_pid_list.append(tensorboard_process.pid)
nni_config.set_config('tensorboardPidList', tensorboard_process_pid_list) experiment_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list)
def stop_tensorboard(args): def stop_tensorboard(args):
'''stop tensorboard''' '''stop tensorboard'''
experiment_id = check_experiment_id(args) experiment_id = check_experiment_id(args)
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() tensorboard_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
nni_config = Config(experiment_id)
tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
if tensorboard_pid_list: if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list: for tensorboard_pid in tensorboard_pid_list:
try: try:
...@@ -104,7 +103,7 @@ def stop_tensorboard(args): ...@@ -104,7 +103,7 @@ def stop_tensorboard(args):
call(cmds) call(cmds)
except Exception as exception: except Exception as exception:
print_error(exception) print_error(exception)
nni_config.set_config('tensorboardPidList', []) experiment_config.update_experiment(experiment_id, 'tensorboardPidList', [])
print_normal('Stop tensorboard success!') print_normal('Stop tensorboard success!')
else: else:
print_error('No tensorboard configuration!') print_error('No tensorboard configuration!')
...@@ -164,4 +163,4 @@ def start_tensorboard(args): ...@@ -164,4 +163,4 @@ def start_tensorboard(args):
os.makedirs(temp_nni_path, exist_ok=True) os.makedirs(temp_nni_path, exist_ok=True)
path_list = get_path_list(args, nni_config, trial_content, temp_nni_path) path_list = get_path_list(args, nni_config, trial_content, temp_nni_path)
start_tensorboard_process(args, nni_config, path_list, temp_nni_path) start_tensorboard_process(args, experiment_id, path_list, temp_nni_path)
\ No newline at end of file
...@@ -5,7 +5,7 @@ import json ...@@ -5,7 +5,7 @@ import json
import os import os
from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url, import_data_url from .url_utils import experiment_url, import_data_url
from .config_utils import Config from .config_utils import Config, Experiments
from .common_utils import get_json_content, print_normal, print_error, print_warning from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import get_experiment_port, get_config_filename, detect_process from .nnictl_utils import get_experiment_port, get_config_filename, detect_process
from .launcher_utils import parse_time from .launcher_utils import parse_time
...@@ -58,8 +58,9 @@ def get_query_type(key): ...@@ -58,8 +58,9 @@ def get_query_type(key):
def update_experiment_profile(args, key, value): def update_experiment_profile(args, key, value):
'''call restful server to update experiment profile''' '''call restful server to update experiment profile'''
nni_config = Config(get_config_filename(args)) experiment_config = Experiments()
rest_port = nni_config.get_config('restServerPort') experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port) running, _ = check_rest_server_quick(rest_port)
if running: if running:
response = rest_get(experiment_url(rest_port), REST_TIME_OUT) response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
......
...@@ -480,7 +480,6 @@ class NNIManager implements Manager { ...@@ -480,7 +480,6 @@ class NNIManager implements Manager {
} }
await this.storeExperimentProfile(); await this.storeExperimentProfile();
this.setStatus('STOPPED'); this.setStatus('STOPPED');
this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'port', undefined);
} }
private async periodicallyUpdateExecDuration(): Promise<void> { private async periodicallyUpdateExecDuration(): Promise<void> {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment