Commit 5b4eaa3f authored by David Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 286409835
parent 83076c16
@@ -27,12 +27,17 @@ from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
 class KerasBenchmark(PerfZeroBenchmark):
   """Base benchmark class with methods to simplify testing."""
 
-  def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
+  def __init__(self,
+               output_dir=None,
+               default_flags=None,
+               flag_methods=None,
+               tpu=None):
     assert tf.version.VERSION.startswith('2.')
     super(KerasBenchmark, self).__init__(
         output_dir=output_dir,
         default_flags=default_flags,
-        flag_methods=flag_methods)
+        flag_methods=flag_methods,
+        tpu=tpu)
 
   def _report_benchmark(self,
                         stats,
@@ -41,7 +46,8 @@ class KerasBenchmark(PerfZeroBenchmark):
                         top_1_min=None,
                         log_steps=None,
                         total_batch_size=None,
-                        warmup=1):
+                        warmup=1,
+                        start_time_sec=None):
     """Report benchmark results by writing to local protobuf file.
 
     Args:
@@ -52,6 +58,7 @@ class KerasBenchmark(PerfZeroBenchmark):
       log_steps: How often the log was created for stats['step_timestamp_log'].
       total_batch_size: Global batch-size.
       warmup: number of entries in stats['step_timestamp_log'] to ignore.
+      start_time_sec: the start time of the program in seconds since epoch.
     """
 
     metrics = []
@@ -78,6 +85,13 @@ class KerasBenchmark(PerfZeroBenchmark):
     if 'avg_exp_per_second' in stats:
       metrics.append({'name': 'avg_exp_per_second',
                       'value': stats['avg_exp_per_second']})
+
+    if start_time_sec and 'step_timestamp_log' in stats:
+      time_log = stats['step_timestamp_log']
+      # time_log[0] is recorded at the beginning of the first step.
+      startup_time = time_log[0].timestamp - start_time_sec
+      metrics.append({'name': 'startup_time', 'value': startup_time})
+
     flags_str = flags_core.get_nondefault_flags_as_str()
     self.report_benchmark(
         iters=-1,
...
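The new startup_time metric is the gap between program start and the timestamp of the first training step. A minimal sketch of that computation, assuming the entries of stats['step_timestamp_log'] expose a .timestamp attribute (as the time_log[0].timestamp access above implies); the BatchTimestamp stand-in below is hypothetical:

```python
import time
from collections import namedtuple

# Hypothetical stand-in for the entries of stats['step_timestamp_log'];
# the diff only shows that each entry has a `.timestamp` attribute.
BatchTimestamp = namedtuple('BatchTimestamp', ['batch_index', 'timestamp'])

start_time_sec = time.time()  # recorded when the program starts

# ... flag parsing, input pipeline and model construction happen here ...

stats = {'step_timestamp_log': [BatchTimestamp(0, time.time())]}

time_log = stats['step_timestamp_log']
# time_log[0] is logged at the beginning of the first step, so the difference
# is pure startup cost, before any training work.
startup_time = time_log[0].timestamp - start_time_sec
print('startup_time: %.3f sec' % startup_time)
```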
@@ -195,13 +195,14 @@ class Resnet50KerasAccuracy(keras_benchmark.KerasBenchmark):
 class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
   """Resnet50 benchmarks."""
 
-  def __init__(self, output_dir=None, default_flags=None):
+  def __init__(self, output_dir=None, default_flags=None, tpu=None):
     flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]
 
     super(Resnet50KerasBenchmarkBase, self).__init__(
         output_dir=output_dir,
         flag_methods=flag_methods,
-        default_flags=default_flags)
+        default_flags=default_flags,
+        tpu=tpu)
 
   @benchmark_wrappers.enable_runtime_flags
   def _run_and_report_benchmark(self, skip_steps=None):
@@ -218,7 +219,8 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
         wall_time_sec,
         total_batch_size=FLAGS.batch_size,
         log_steps=FLAGS.log_steps,
-        warmup=warmup)
+        warmup=warmup,
+        start_time_sec=start_time_sec)
 
   def benchmark_1_gpu_no_dist_strat(self):
     """Test Keras model with 1 GPU, no distribution strategy."""
@@ -798,6 +800,26 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
     self._run_and_report_benchmark()
 
+  def benchmark_2x2_tpu_fp16(self):
+    """Test Keras model with 2x2 TPU, fp16."""
+    self._setup()
+
+    FLAGS.dtype = 'bf16'
+    FLAGS.distribution_strategy = 'tpu'
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_fp16')
+    FLAGS.batch_size = 1024
+    self._run_and_report_benchmark()
+
+  def benchmark_4x4_tpu_fp16(self):
+    """Test Keras model with 4x4 TPU, fp16."""
+    self._setup()
+
+    FLAGS.dtype = 'bf16'
+    FLAGS.distribution_strategy = 'tpu'
+    FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu_fp16')
+    FLAGS.batch_size = 4096
+    self._run_and_report_benchmark()
+
   def fill_report_object(self, stats):
     super(Resnet50KerasBenchmarkBase, self).fill_report_object(
         stats,
@@ -808,7 +830,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
 class Resnet50KerasBenchmarkSynth(Resnet50KerasBenchmarkBase):
   """Resnet50 synthetic benchmark tests."""
 
-  def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+  def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
     def_flags = {}
     def_flags['skip_eval'] = True
     def_flags['report_accuracy_metrics'] = False
@@ -817,13 +839,13 @@ class Resnet50KerasBenchmarkSynth(Resnet50KerasBenchmarkBase):
     def_flags['log_steps'] = 10
 
     super(Resnet50KerasBenchmarkSynth, self).__init__(
-        output_dir=output_dir, default_flags=def_flags)
+        output_dir=output_dir, default_flags=def_flags, tpu=tpu)
 
 
 class Resnet50KerasBenchmarkReal(Resnet50KerasBenchmarkBase):
   """Resnet50 real data benchmark tests."""
 
-  def __init__(self, output_dir=None, root_data_dir=None, **kwargs):
+  def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
     def_flags = {}
     def_flags['skip_eval'] = True
     def_flags['report_accuracy_metrics'] = False
@@ -832,7 +854,7 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasBenchmarkBase):
     def_flags['log_steps'] = 10
 
     super(Resnet50KerasBenchmarkReal, self).__init__(
-        output_dir=output_dir, default_flags=def_flags)
+        output_dir=output_dir, default_flags=def_flags, tpu=tpu)
 
 
 class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
...
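With tpu now threaded from the concrete benchmark classes down to PerfZeroBenchmark, a benchmark can be pointed at a TPU target when it is constructed. A hypothetical invocation outside the usual PerfZero harness (the module path, output directory, and TPU name are placeholders):

```python
# Assumed module path for the file defining the ResNet50 Keras benchmarks.
from official.benchmark import keras_imagenet_benchmark

benchmark = keras_imagenet_benchmark.Resnet50KerasBenchmarkSynth(
    output_dir='/tmp/resnet50_benchmark',  # placeholder output directory
    tpu='my-tpu-worker')                   # placeholder TPU name or address

# Runs the new 2x2 TPU configuration: bf16 dtype, 'tpu' distribution strategy,
# global batch size 1024, and reports throughput plus the startup_time metric.
benchmark.benchmark_2x2_tpu_fp16()
```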