Unverified Commit e6cedb89 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Fix metrics nan (#2070)

parent 96afbdc4
......@@ -4,7 +4,6 @@
'use strict';
import * as assert from 'assert';
import * as JSON5 from 'json5';
import { Deferred } from 'ts-deferred';
import * as component from '../common/component';
......@@ -132,7 +131,7 @@ class NNIDataStore implements DataStore {
}
public async storeMetricData(trialJobId: string, data: string): Promise<void> {
const metrics: MetricData = JSON5.parse(data);
const metrics: MetricData = JSON.parse(data);
// REQUEST_PARAMETER is used to request new parameters for multiphase trial job,
// it is not metrics, so it is skipped here.
if (metrics.type === 'REQUEST_PARAMETER') {
......@@ -141,7 +140,7 @@ class NNIDataStore implements DataStore {
}
assert(trialJobId === metrics.trial_job_id);
try {
await this.db.storeMetricData(trialJobId, JSON5.stringify({
await this.db.storeMetricData(trialJobId, JSON.stringify({
trialJobId: metrics.trial_job_id,
parameterId: metrics.parameter_id,
type: metrics.type,
......
......@@ -5,7 +5,6 @@
import * as assert from 'assert';
import * as fs from 'fs';
import * as JSON5 from 'json5';
import * as path from 'path';
import * as sqlite3 from 'sqlite3';
import { Deferred } from 'ts-deferred';
......@@ -203,10 +202,10 @@ class SqlDB implements Database {
public storeMetricData(trialJobId: string, data: string): Promise<void> {
const sql: string = 'insert into MetricData values (?,?,?,?,?,?)';
const json: MetricDataRecord = JSON5.parse(data);
const args: any[] = [Date.now(), json.trialJobId, json.parameterId, json.type, json.sequence, JSON5.stringify(json.data)];
const json: MetricDataRecord = JSON.parse(data);
const args: any[] = [Date.now(), json.trialJobId, json.parameterId, json.type, json.sequence, JSON.stringify(json.data)];
this.log.trace(`storeMetricData: SQL: ${sql}, args: ${JSON5.stringify(args)}`);
this.log.trace(`storeMetricData: SQL: ${sql}, args: ${JSON.stringify(args)}`);
const deferred: Deferred<void> = new Deferred<void>();
this.db.run(sql, args, (err: Error | null) => { this.resolve(deferred, err); });
......
......@@ -17,7 +17,6 @@
"express": "^4.16.3",
"express-joi-validator": "^2.0.0",
"js-base64": "^2.4.9",
"json5": "^2.1.1",
"kubernetes-client": "^6.5.0",
"rx": "^4.1.0",
"sqlite3": "^4.0.2",
......@@ -36,7 +35,6 @@
"@types/express": "^4.16.0",
"@types/glob": "^7.1.1",
"@types/js-base64": "^2.3.1",
"@types/json5": "^0.0.30",
"@types/mocha": "^5.2.5",
"@types/node": "10.12.18",
"@types/request": "^2.47.1",
......
......@@ -157,10 +157,6 @@
version "7.0.3"
resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.3.tgz#bdfd69d61e464dcc81b25159c270d75a73c1a636"
"@types/json5@^0.0.30":
version "0.0.30"
resolved "https://registry.yarnpkg.com/@types/json5/-/json5-0.0.30.tgz#44cb52f32a809734ca562e685c6473b5754a7818"
"@types/mime@*":
version "2.0.0"
resolved "https://registry.yarnpkg.com/@types/mime/-/mime-2.0.0.tgz#5a7306e367c539b9f6543499de8dd519fac37a8b"
......@@ -2380,12 +2376,6 @@ json-stringify-safe@~5.0.1:
version "5.0.1"
resolved "https://registry.yarnpkg.com/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz#1296a2d58fd45f19a0f6ce01d65701e2c735b6eb"
json5@^2.1.1:
version "2.1.1"
resolved "https://registry.yarnpkg.com/json5/-/json5-2.1.1.tgz#81b6cb04e9ba496f1c7005d07b4368a2638f90b6"
dependencies:
minimist "^1.2.0"
jsonparse@^1.2.0:
version "1.3.1"
resolved "https://registry.yarnpkg.com/jsonparse/-/jsonparse-1.3.1.tgz#3f4dae4a91fac315f71062f8521cc239f1366280"
......
......@@ -127,6 +127,8 @@ class MsgDispatcher(MsgDispatcherBase):
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
# metrics value is dumped as json string in trial, so we need to decode it here
data['value'] = json_tricks.loads(data['value'])
if data['type'] == MetricType.FINAL:
self._handle_final_metric_data(data)
elif data['type'] == MetricType.PERIODICAL:
......
......@@ -33,4 +33,7 @@ def init_params(params):
_params = copy.deepcopy(params)
def get_last_metric():
return json_tricks.loads(_last_metric)
metrics = json_tricks.loads(_last_metric)
metrics['value'] = json_tricks.loads(metrics['value'])
return metrics
......@@ -114,7 +114,7 @@ def report_intermediate_result(metric):
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL',
'sequence': _intermediate_seq,
'value': metric
'value': to_json(metric)
})
_intermediate_seq += 1
platform.send_metric(metric)
......@@ -135,6 +135,6 @@ def report_final_result(metric):
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL',
'sequence': 0,
'value': metric
'value': to_json(metric)
})
platform.send_metric(metric)
......@@ -47,9 +47,9 @@ def _restore_io():
class AssessorTestCase(TestCase):
def test_assessor(self):
_reverse_io()
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":2}')
send(CommandType.ReportMetricData, '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":2}')
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":3}')
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}')
send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED"}')
send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED"}')
send(CommandType.NewTrialJob, 'null')
......
......@@ -59,8 +59,8 @@ class MsgDispatcherTestCase(TestCase):
def test_msg_dispatcher(self):
_reverse_io() # now we are sending to Tuner's incoming stream
send(CommandType.RequestTrialJobs, '2')
send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10}')
send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11}')
send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":"10"}')
send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":"11"}')
send(CommandType.UpdateSearchSpace, '{"name":"SS0"}')
send(CommandType.RequestTrialJobs, '1')
send(CommandType.KillTrialJob, 'null')
......
......@@ -56,9 +56,9 @@ def get_metric_results(metrics):
final_result = []
for metric in metrics:
if metric['type'] == 'PERIODICAL':
intermediate_result.append(metric['data'])
intermediate_result.append(json.loads(metric['data']))
elif metric['type'] == 'FINAL':
final_result.append(metric['data'])
final_result.append(json.loads(metric['data']))
print(intermediate_result, final_result)
return [round(float(x),6) for x in intermediate_result], [round(float(x), 6) for x in final_result]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment