Commit 7c08a2d0 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 322884473
parent 1cbc2345
......@@ -299,20 +299,21 @@ class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark):
return os.path.join(self.output_dir, folder_name)
class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Resnet50 (classifier_trainer) benchmarks."""
class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Classifier Trainer benchmarks."""
def __init__(self, output_dir=None, default_flags=None,
def __init__(self, model, output_dir=None, default_flags=None,
tpu=None, dataset_builder='records', train_epochs=1,
train_steps=110, data_dir=None):
flag_methods = [classifier_trainer.define_classifier_flags]
self.model = model
self.dataset_builder = dataset_builder
self.train_epochs = train_epochs
self.train_steps = train_steps
self.data_dir = data_dir
super(Resnet50KerasClassifierBenchmarkBase, self).__init__(
super(KerasClassifierBenchmarkBase, self).__init__(
output_dir=output_dir,
flag_methods=flag_methods,
default_flags=default_flags,
......@@ -337,7 +338,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
dataset_num_private_threads: Optional[int] = None,
loss_scale: Optional[str] = None):
"""Runs and reports the benchmark given the provided configuration."""
FLAGS.model_type = 'resnet'
FLAGS.model_type = self.model
FLAGS.dataset = 'imagenet'
FLAGS.mode = 'train_and_eval'
FLAGS.data_dir = self.data_dir
......@@ -372,7 +373,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
# input skip_steps.
warmup = (skip_steps or (self.train_steps - 100)) // FLAGS.log_steps
super(Resnet50KerasClassifierBenchmarkBase, self)._report_benchmark(
super(KerasClassifierBenchmarkBase, self)._report_benchmark(
stats,
wall_time_sec,
total_batch_size=total_batch_size,
......@@ -599,8 +600,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
distribution_strategy='mirrored',
per_replica_batch_size=256,
gpu_thread_mode='gpu_private',
dataset_num_private_threads=48,
steps=310)
dataset_num_private_threads=48)
def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
"""Tests Keras model with config tuning, XLA, 8 GPUs and dynamic fp16."""
......@@ -636,6 +636,28 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_2x2_tpu_bf16_mlir(self):
"""Test Keras model with 2x2 TPU, bf16."""
self._setup()
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
experiment_name='benchmark_2x2_tpu_bf16_mlir',
dtype='bfloat16',
num_tpus=8,
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_4x4_tpu_bf16_mlir(self):
"""Test Keras model with 4x4 TPU, bf16."""
self._setup()
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
experiment_name='benchmark_4x4_tpu_bf16_mlir',
dtype='bfloat16',
num_tpus=32,
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_8x8_tpu_bf16(self):
"""Test Keras model with 8x8 TPU, bf16."""
self._setup()
......@@ -647,7 +669,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
per_replica_batch_size=64)
def fill_report_object(self, stats):
super(Resnet50KerasClassifierBenchmarkBase, self).fill_report_object(
super(KerasClassifierBenchmarkBase, self).fill_report_object(
stats,
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
......@@ -1086,7 +1108,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
log_steps=FLAGS.log_steps)
class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase):
class Resnet50KerasBenchmarkSynth(KerasClassifierBenchmarkBase):
"""Resnet50 synthetic benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
......@@ -1094,11 +1116,11 @@ class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase):
def_flags['log_steps'] = 10
super(Resnet50KerasBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags, tpu=tpu,
model='resnet', output_dir=output_dir, default_flags=def_flags, tpu=tpu,
dataset_builder='synthetic', train_epochs=1, train_steps=110)
class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase):
class Resnet50KerasBenchmarkReal(KerasClassifierBenchmarkBase):
"""Resnet50 real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
......@@ -1107,11 +1129,25 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase):
def_flags['log_steps'] = 10
super(Resnet50KerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags, tpu=tpu,
model='resnet', output_dir=output_dir, default_flags=def_flags, tpu=tpu,
dataset_builder='records', train_epochs=1, train_steps=110,
data_dir=data_dir)
class EfficientNetKerasBenchmarkReal(KerasClassifierBenchmarkBase):
"""EfficientNet real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
data_dir = os.path.join(root_data_dir, 'imagenet')
def_flags = {}
def_flags['log_steps'] = 10
super(EfficientNetKerasBenchmarkReal, self).__init__(
model='efficientnet', output_dir=output_dir, default_flags=def_flags,
tpu=tpu, dataset_builder='records', train_epochs=1, train_steps=110,
data_dir=data_dir)
class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
"""Resnet50 real data (stored in remote storage) benchmark tests."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment