Unverified Commit 62d74565 authored by Yan Ni's avatar Yan Ni Committed by GitHub
Browse files

fix trial export (#2303)

parent 2b77ab2d
......@@ -77,6 +77,14 @@ testCases:
kwargs:
expected_result_file: expected_metrics.json
- name: export-float
configFile: test/config/metrics_test/config.yml
config:
maxTrialNum: 1
trialConcurrency: 1
validator:
class: ExportValidator
- name: metrics-dict
configFile: test/config/metrics_test/config_dict_metrics.yml
config:
......@@ -87,6 +95,14 @@ testCases:
kwargs:
expected_result_file: expected_metrics_dict.json
- name: export-dict
configFile: test/config/metrics_test/config_dict_metrics.yml
config:
maxTrialNum: 1
trialConcurrency: 1
validator:
class: ExportValidator
- name: nnicli
configFile: test/config/examples/sklearn-regression.yml
config:
......
......@@ -2,6 +2,8 @@
# Licensed under the MIT license.
import os.path as osp
from os import remove
import subprocess
import json
import requests
import nnicli as nc
......@@ -12,6 +14,24 @@ class ITValidator:
def __call__(self, rest_endpoint, experiment_dir, nni_source_dir, **kwargs):
pass
class ExportValidator(ITValidator):
def __call__(self, rest_endpoint, experiment_dir, nni_source_dir, **kwargs):
exp_id = osp.split(experiment_dir)[-1]
proc1 = subprocess.run(["nnictl", "experiment", "export", exp_id, "-t", "csv", "-f", "report.csv"])
assert proc1.returncode == 0, '`nnictl experiment export -t csv` failed with code %d' % proc1.returncode
with open("report.csv", 'r') as f:
print('Exported CSV file: \n')
print(''.join(f.readlines()))
print('\n\n')
remove('report.csv')
proc2 = subprocess.run(["nnictl", "experiment", "export", exp_id, "-t", "json", "-f", "report.json"])
assert proc2.returncode == 0, '`nnictl experiment export -t json` failed with code %d' % proc2.returncode
with open("report.json", 'r') as f:
print('Exported JSON file: \n')
print('\n'.join(f.readlines()))
print('\n\n')
remove('report.json')
class MetricsValidator(ITValidator):
def __call__(self, rest_endpoint, experiment_dir, nni_source_dir, **kwargs):
......
......@@ -699,12 +699,13 @@ def export_trials_data(args):
content = json.loads(response.text)
trial_records = []
for record in content:
if not isinstance(record['value'], (float, int)):
formated_record = {**record['parameter'], **record['value'], **{'id': record['id']}}
record_value = json.loads(record['value'])
if not isinstance(record_value, (float, int)):
formated_record = {**record['parameter'], **record_value, **{'id': record['id']}}
else:
formated_record = {**record['parameter'], **{'reward': record['value'], 'id': record['id']}}
formated_record = {**record['parameter'], **{'reward': record_value, 'id': record['id']}}
trial_records.append(formated_record)
with open(args.path, 'w') as file:
with open(args.path, 'w', newline='') as file:
writer = csv.DictWriter(file, set.union(*[set(r.keys()) for r in trial_records]))
writer.writeheader()
writer.writerows(trial_records)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment