Unverified Commit d654eff4 authored by Tab Zhang's avatar Tab Zhang Committed by GitHub
Browse files

feature: export experiment results (#2706)

parent bcefce6a
......@@ -462,13 +462,14 @@ Debug mode will disable version check function in Trialkeeper.
|id| False| |ID of the experiment |
|--filename, -f| True| |File path of the output file |
|--type| True| |Type of output file, only support "csv" and "json"|
|--intermediate, -i|False||Are intermediate results included|
* Examples
> export all trial data in an experiment as json format
```bash
nnictl experiment export [experiment_id] --filename [file_path] --type json
nnictl experiment export [experiment_id] --filename [file_path] --type json --intermediate
```
* __nnictl experiment import__
......@@ -903,4 +904,3 @@ Debug mode will disable version check function in Trialkeeper.
```bash
nnictl --version
```
......@@ -140,6 +140,8 @@ def parse_args():
parser_trial_export.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_export.add_argument('--type', '-t', choices=['json', 'csv'], required=True, dest='type', help='target file type')
parser_trial_export.add_argument('--filename', '-f', required=True, dest='path', help='target file path')
parser_trial_export.add_argument('--intermediate', '-i', action='store_true',
default=False, help='are intermediate results included')
parser_trial_export.set_defaults(func=export_trials_data)
#save an NNI experiment
parser_save_experiment = parser_experiment_subparsers.add_parser('save', help='save an experiment')
......
......@@ -16,7 +16,7 @@ from pyhdfs import HdfsClient
from nni.package_utils import get_nni_installation_path
from nni_annotation import expand_annotations
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url, metric_data_url
from .config_utils import Config, Experiments
from .constants import NNICTL_HOME_DIR, NNI_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
......@@ -681,45 +681,64 @@ def monitor_experiment(args):
set_monitor(False, args.time)
def export_trials_data(args):
'''export experiment metadata to csv
'''export experiment metadata and intermediate results to json or csv
'''
def groupby_trial_id(intermediate_results):
sorted(intermediate_results, key=lambda x: x['timestamp'])
groupby = dict()
for content in intermediate_results:
groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data']))
return groupby
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
running, response = check_rest_server_quick(rest_port)
if running:
response = rest_get(export_data_url(rest_port), 20)
if response is not None and check_response(response):
if args.type == 'json':
with open(args.path, 'w') as file:
file.write(response.text)
elif args.type == 'csv':
content = json.loads(response.text)
trial_records = []
for record in content:
record_value = json.loads(record['value'])
if not isinstance(record_value, (float, int)):
formated_record = {**record['parameter'], **record_value, **{'id': record['id']}}
else:
formated_record = {**record['parameter'], **{'reward': record_value, 'id': record['id']}}
trial_records.append(formated_record)
if not trial_records:
print_error('No trial results collected! Please check your trial log...')
exit(0)
with open(args.path, 'w', newline='') as file:
writer = csv.DictWriter(file, set.union(*[set(r.keys()) for r in trial_records]))
writer.writeheader()
writer.writerows(trial_records)
else:
print_error('Unknown type: %s' % args.type)
exit(1)
if not running:
print_error('Restful server is not running')
return
response = rest_get(export_data_url(rest_port), 20)
if response is not None and check_response(response):
content = json.loads(response.text)
if args.intermediate:
intermediate_results_response = rest_get(metric_data_url(rest_port), REST_TIME_OUT)
if not intermediate_results_response or not check_response(intermediate_results_response):
print_error('Error getting intermediate results.')
return
intermediate_results = groupby_trial_id(json.loads(intermediate_results_response.text))
for record in content:
record['intermediate'] = intermediate_results[record['id']]
if args.type == 'json':
with open(args.path, 'w') as file:
file.write(json.dumps(content))
elif args.type == 'csv':
trial_records = []
for record in content:
formated_record = dict()
if args.intermediate:
formated_record['intermediate'] = '[' + ','.join(record['intermediate']) + ']'
record_value = json.loads(record['value'])
if not isinstance(record_value, (float, int)):
formated_record.update({**record['parameter'], **record_value, **{'id': record['id']}})
else:
formated_record.update({**record['parameter'], **{'reward': record_value, 'id': record['id']}})
trial_records.append(formated_record)
if not trial_records:
print_error('No trial results collected! Please check your trial log...')
exit(0)
with open(args.path, 'w', newline='') as file:
writer = csv.DictWriter(file, set.union(*[set(r.keys()) for r in trial_records]))
writer.writeheader()
writer.writerows(trial_records)
else:
print_error('Export failed...')
print_error('Unknown type: %s' % args.type)
return
else:
print_error('Restful server is not Running')
print_error('Export failed...')
def search_space_auto_gen(args):
'''dry run trial code to generate search space file'''
......@@ -898,3 +917,4 @@ def load_experiment(args):
# Step6. Cleanup temp data
shutil.rmtree(temp_root_dir)
......@@ -22,6 +22,12 @@ EXPORT_DATA_API = '/export-data'
TENSORBOARD_API = '/tensorboard'
METRIC_DATA_API = '/metric-data'
def metric_data_url(port):
'''get metric_data url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, METRIC_DATA_API)
def check_status_url(port):
'''get check_status url'''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment