Commit 27e86174 authored by Toby Boyd's avatar Toby Boyd Committed by Dong Lin
Browse files

Fix accuracy name (#6179)

parent c9285547
...@@ -19,8 +19,8 @@ from __future__ import division ...@@ -19,8 +19,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import json import json
import time
import os import os
import time
from absl import flags from absl import flags
from absl.testing import flagsaver from absl.testing import flagsaver
...@@ -88,7 +88,7 @@ class EstimatorCifar10BenchmarkTests(tf.test.Benchmark): ...@@ -88,7 +88,7 @@ class EstimatorCifar10BenchmarkTests(tf.test.Benchmark):
self._run_and_report_benchmark() self._run_and_report_benchmark()
def unit_test(self): def unit_test(self):
"""A lightweigth test that can finish quickly""" """A lightweight test that can finish quickly."""
self._setup() self._setup()
flags.FLAGS.num_gpus = 1 flags.FLAGS.num_gpus = 1
flags.FLAGS.data_dir = DATA_DIR flags.FLAGS.data_dir = DATA_DIR
...@@ -108,7 +108,7 @@ class EstimatorCifar10BenchmarkTests(tf.test.Benchmark): ...@@ -108,7 +108,7 @@ class EstimatorCifar10BenchmarkTests(tf.test.Benchmark):
iters=stats['global_step'], iters=stats['global_step'],
wall_time=wall_time_sec, wall_time=wall_time_sec,
extras={ extras={
'accuracy': 'accuracy_top_1':
self._json_description(stats['accuracy'].item(), priority=0), self._json_description(stats['accuracy'].item(), priority=0),
'accuracy_top_5': 'accuracy_top_5':
self._json_description(stats['accuracy_top_5'].item()), self._json_description(stats['accuracy_top_5'].item()),
...@@ -119,7 +119,7 @@ class EstimatorCifar10BenchmarkTests(tf.test.Benchmark): ...@@ -119,7 +119,7 @@ class EstimatorCifar10BenchmarkTests(tf.test.Benchmark):
priority=None, priority=None,
min_value=None, min_value=None,
max_value=None): max_value=None):
"""Get a json-formatted string describing the attributes for a metric""" """Get a json-formatted string describing the attributes for a metric."""
attributes = {} attributes = {}
attributes['value'] = value attributes['value'] = value
......
...@@ -65,7 +65,7 @@ class KerasBenchmark(tf.test.Benchmark): ...@@ -65,7 +65,7 @@ class KerasBenchmark(tf.test.Benchmark):
log_steps=None, log_steps=None,
total_batch_size=None, total_batch_size=None,
warmup=1): warmup=1):
"""Report benchmark results by writing to local protobuf file """Report benchmark results by writing to local protobuf file.
Args: Args:
stats: dict returned from keras models with known entries. stats: dict returned from keras models with known entries.
...@@ -79,7 +79,7 @@ class KerasBenchmark(tf.test.Benchmark): ...@@ -79,7 +79,7 @@ class KerasBenchmark(tf.test.Benchmark):
extras = {} extras = {}
if 'accuracy_top_1' in stats: if 'accuracy_top_1' in stats:
extras['accuracy'] = self._json_description( extras['accuracy_top_1'] = self._json_description(
stats['accuracy_top_1'], stats['accuracy_top_1'],
priority=0, priority=0,
min_value=top_1_min, min_value=top_1_min,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment