Unverified Commit d3610769 authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Use PerfZeroBenchmark and log Flags. (#7052)

parent d62ddcaa
......@@ -18,45 +18,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from official.utils.flags import core as flags_core
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf # pylint: disable=g-bad-import-order
FLAGS = flags.FLAGS
class KerasBenchmark(tf.test.Benchmark):
class KerasBenchmark(PerfZeroBenchmark):
"""Base benchmark class with methods to simplify testing."""
local_flags = None
def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
self.output_dir = output_dir
self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {}
if not output_dir:
self.output_dir = '/tmp/'
def _get_model_dir(self, folder_name):
return os.path.join(self.output_dir, folder_name)
def _setup(self):
"""Sets up and resets flags before each test."""
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
if KerasBenchmark.local_flags is None:
for flag_method in self.flag_methods:
flag_method()
# Loads flags to get defaults to then override. List cannot be empty.
flags.FLAGS(['foo'])
# Overrides flag values with defaults for the class of tests.
for k, v in self.default_flags.items():
setattr(FLAGS, k, v)
saved_flag_values = flagsaver.save_flag_values()
KerasBenchmark.local_flags = saved_flag_values
else:
flagsaver.restore_flag_values(KerasBenchmark.local_flags)
super(KerasBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
flag_methods=flag_methods)
def _report_benchmark(self,
stats,
......@@ -102,5 +75,9 @@ class KerasBenchmark(tf.test.Benchmark):
if 'avg_exp_per_second' in stats:
metrics.append({'name': 'avg_exp_per_second',
'value': stats['avg_exp_per_second']})
self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics)
flags_str = flags_core.get_nondefault_flags_as_str()
self.report_benchmark(
iters=-1,
wall_time=wall_time_sec,
metrics=metrics,
extras={'flags': flags_str})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment