Unverified Commit d654eff4 authored by Tab Zhang's avatar Tab Zhang Committed by GitHub
Browse files

feature: export experiment results (#2706)

parent bcefce6a
...@@ -462,13 +462,14 @@ Debug mode will disable version check function in Trialkeeper. ...@@ -462,13 +462,14 @@ Debug mode will disable version check function in Trialkeeper.
|id| False| |ID of the experiment | |id| False| |ID of the experiment |
|--filename, -f| True| |File path of the output file | |--filename, -f| True| |File path of the output file |
|--type| True| |Type of output file, only support "csv" and "json"| |--type| True| |Type of output file, only support "csv" and "json"|
|--intermediate, -i|False||Are intermediate results included|
* Examples
  > export all trial data in an experiment as json format
  ```bash
  nnictl experiment export [experiment_id] --filename [file_path] --type json --intermediate
  ```
* __nnictl experiment import__ * __nnictl experiment import__
...@@ -903,4 +904,3 @@ Debug mode will disable version check function in Trialkeeper. ...@@ -903,4 +904,3 @@ Debug mode will disable version check function in Trialkeeper.
```bash ```bash
nnictl --version nnictl --version
``` ```
...@@ -140,6 +140,8 @@ def parse_args(): ...@@ -140,6 +140,8 @@ def parse_args():
parser_trial_export.add_argument('id', nargs='?', help='the id of experiment') parser_trial_export.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_export.add_argument('--type', '-t', choices=['json', 'csv'], required=True, dest='type', help='target file type') parser_trial_export.add_argument('--type', '-t', choices=['json', 'csv'], required=True, dest='type', help='target file type')
parser_trial_export.add_argument('--filename', '-f', required=True, dest='path', help='target file path') parser_trial_export.add_argument('--filename', '-f', required=True, dest='path', help='target file path')
parser_trial_export.add_argument('--intermediate', '-i', action='store_true',
default=False, help='are intermediate results included')
parser_trial_export.set_defaults(func=export_trials_data) parser_trial_export.set_defaults(func=export_trials_data)
#save an NNI experiment #save an NNI experiment
parser_save_experiment = parser_experiment_subparsers.add_parser('save', help='save an experiment') parser_save_experiment = parser_experiment_subparsers.add_parser('save', help='save an experiment')
......
...@@ -16,7 +16,7 @@ from pyhdfs import HdfsClient ...@@ -16,7 +16,7 @@ from pyhdfs import HdfsClient
from nni.package_utils import get_nni_installation_path from nni.package_utils import get_nni_installation_path
from nni_annotation import expand_annotations from nni_annotation import expand_annotations
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url, metric_data_url
from .config_utils import Config, Experiments from .config_utils import Config, Experiments
from .constants import NNICTL_HOME_DIR, NNI_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \ from .constants import NNICTL_HOME_DIR, NNI_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
...@@ -681,45 +681,64 @@ def monitor_experiment(args): ...@@ -681,45 +681,64 @@ def monitor_experiment(args):
set_monitor(False, args.time) set_monitor(False, args.time)
def export_trials_data(args):
    '''Export experiment trial metadata (and optionally intermediate results)
    to a json or csv file.

    args is the argparse namespace from ``nnictl experiment export``:
        args.id           -- experiment id (resolved via get_config_filename)
        args.path         -- output file path (from --filename/-f)
        args.type         -- 'json' or 'csv' (from --type/-t)
        args.intermediate -- include per-trial intermediate results (--intermediate/-i)

    Prints an error and returns without writing a file when the experiment or
    its REST server is not running, or when any REST call fails.
    '''
    def groupby_trial_id(intermediate_results):
        '''Group raw metric records by trial job id, ordered by timestamp.

        Each record's 'data' field is a JSON-encoded value; it is decoded here
        so callers receive the parsed metric values.
        '''
        # Bug fix: the original called sorted() and discarded its return
        # value, so records were never actually ordered by timestamp.
        groupby = dict()
        for content in sorted(intermediate_results, key=lambda x: x['timestamp']):
            groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data']))
        return groupby

    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    rest_pid = nni_config.get_config('restServerPid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return
    running, response = check_rest_server_quick(rest_port)
    if not running:
        print_error('Restful server is not running')
        return
    response = rest_get(export_data_url(rest_port), 20)
    if response is None or not check_response(response):
        print_error('Export failed...')
        return
    content = json.loads(response.text)
    if args.intermediate:
        intermediate_results_response = rest_get(metric_data_url(rest_port), REST_TIME_OUT)
        if not intermediate_results_response or not check_response(intermediate_results_response):
            print_error('Error getting intermediate results.')
            return
        intermediate_results = groupby_trial_id(json.loads(intermediate_results_response.text))
        for record in content:
            # Bug fix: a trial that reported no metrics has no entry in the
            # grouped dict; default to an empty list instead of raising KeyError.
            record['intermediate'] = intermediate_results.get(record['id'], [])
    if args.type == 'json':
        with open(args.path, 'w') as file:
            file.write(json.dumps(content))
    elif args.type == 'csv':
        trial_records = []
        for record in content:
            formated_record = dict()
            if args.intermediate:
                # Bug fix: the values are already json.loads()-decoded (dicts,
                # numbers, ...), so ','.join() on them raised TypeError.
                # Re-serialize each element so the cell is a valid JSON array.
                formated_record['intermediate'] = \
                    '[' + ','.join(json.dumps(v) for v in record['intermediate']) + ']'
            record_value = json.loads(record['value'])
            if not isinstance(record_value, (float, int)):
                # Final result is a dict of metrics: flatten it into the row.
                formated_record.update({**record['parameter'], **record_value, **{'id': record['id']}})
            else:
                # Scalar final result: store it under the conventional 'reward' column.
                formated_record.update({**record['parameter'], **{'reward': record_value, 'id': record['id']}})
            trial_records.append(formated_record)
        if not trial_records:
            print_error('No trial results collected! Please check your trial log...')
            exit(0)
        # newline='' so csv handles its own line endings portably.
        with open(args.path, 'w', newline='') as file:
            writer = csv.DictWriter(file, set.union(*[set(r.keys()) for r in trial_records]))
            writer.writeheader()
            writer.writerows(trial_records)
    else:
        # Defensive: argparse 'choices' should already prevent this branch.
        print_error('Unknown type: %s' % args.type)
def search_space_auto_gen(args): def search_space_auto_gen(args):
'''dry run trial code to generate search space file''' '''dry run trial code to generate search space file'''
...@@ -898,3 +917,4 @@ def load_experiment(args): ...@@ -898,3 +917,4 @@ def load_experiment(args):
# Step6. Cleanup temp data # Step6. Cleanup temp data
shutil.rmtree(temp_root_dir) shutil.rmtree(temp_root_dir)
...@@ -22,6 +22,12 @@ EXPORT_DATA_API = '/export-data' ...@@ -22,6 +22,12 @@ EXPORT_DATA_API = '/export-data'
TENSORBOARD_API = '/tensorboard' TENSORBOARD_API = '/tensorboard'
# REST API suffix for the per-trial metric (intermediate result) endpoint.
METRIC_DATA_API = '/metric-data'

def metric_data_url(port):
    '''Return the full REST URL of the metric-data endpoint on the given port.'''
    url_template = '{0}:{1}{2}{3}'
    return url_template.format(BASE_URL, port, API_ROOT_URL, METRIC_DATA_API)
def check_status_url(port): def check_status_url(port):
'''get check_status url''' '''get check_status url'''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment