Commit 7c08a2d0 authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 322884473
parent 1cbc2345
...@@ -299,20 +299,21 @@ class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark): ...@@ -299,20 +299,21 @@ class MobilenetV1KerasAccuracy(keras_benchmark.KerasBenchmark):
return os.path.join(self.output_dir, folder_name) return os.path.join(self.output_dir, folder_name)
class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
"""Resnet50 (classifier_trainer) benchmarks.""" """Classifier Trainer benchmarks."""
def __init__(self, output_dir=None, default_flags=None, def __init__(self, model, output_dir=None, default_flags=None,
tpu=None, dataset_builder='records', train_epochs=1, tpu=None, dataset_builder='records', train_epochs=1,
train_steps=110, data_dir=None): train_steps=110, data_dir=None):
flag_methods = [classifier_trainer.define_classifier_flags] flag_methods = [classifier_trainer.define_classifier_flags]
self.model = model
self.dataset_builder = dataset_builder self.dataset_builder = dataset_builder
self.train_epochs = train_epochs self.train_epochs = train_epochs
self.train_steps = train_steps self.train_steps = train_steps
self.data_dir = data_dir self.data_dir = data_dir
super(Resnet50KerasClassifierBenchmarkBase, self).__init__( super(KerasClassifierBenchmarkBase, self).__init__(
output_dir=output_dir, output_dir=output_dir,
flag_methods=flag_methods, flag_methods=flag_methods,
default_flags=default_flags, default_flags=default_flags,
...@@ -337,7 +338,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -337,7 +338,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
dataset_num_private_threads: Optional[int] = None, dataset_num_private_threads: Optional[int] = None,
loss_scale: Optional[str] = None): loss_scale: Optional[str] = None):
"""Runs and reports the benchmark given the provided configuration.""" """Runs and reports the benchmark given the provided configuration."""
FLAGS.model_type = 'resnet' FLAGS.model_type = self.model
FLAGS.dataset = 'imagenet' FLAGS.dataset = 'imagenet'
FLAGS.mode = 'train_and_eval' FLAGS.mode = 'train_and_eval'
FLAGS.data_dir = self.data_dir FLAGS.data_dir = self.data_dir
...@@ -372,7 +373,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -372,7 +373,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
# input skip_steps. # input skip_steps.
warmup = (skip_steps or (self.train_steps - 100)) // FLAGS.log_steps warmup = (skip_steps or (self.train_steps - 100)) // FLAGS.log_steps
super(Resnet50KerasClassifierBenchmarkBase, self)._report_benchmark( super(KerasClassifierBenchmarkBase, self)._report_benchmark(
stats, stats,
wall_time_sec, wall_time_sec,
total_batch_size=total_batch_size, total_batch_size=total_batch_size,
...@@ -599,8 +600,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -599,8 +600,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
distribution_strategy='mirrored', distribution_strategy='mirrored',
per_replica_batch_size=256, per_replica_batch_size=256,
gpu_thread_mode='gpu_private', gpu_thread_mode='gpu_private',
dataset_num_private_threads=48, dataset_num_private_threads=48)
steps=310)
def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self): def benchmark_xla_8_gpu_fp16_dynamic_tweaked(self):
"""Tests Keras model with config tuning, XLA, 8 GPUs and dynamic fp16.""" """Tests Keras model with config tuning, XLA, 8 GPUs and dynamic fp16."""
...@@ -636,6 +636,28 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -636,6 +636,28 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
distribution_strategy='tpu', distribution_strategy='tpu',
per_replica_batch_size=128) per_replica_batch_size=128)
def benchmark_2x2_tpu_bf16_mlir(self):
"""Test Keras model with 2x2 TPU, bf16."""
self._setup()
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
experiment_name='benchmark_2x2_tpu_bf16_mlir',
dtype='bfloat16',
num_tpus=8,
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_4x4_tpu_bf16_mlir(self):
"""Test Keras model with 4x4 TPU, bf16."""
self._setup()
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
experiment_name='benchmark_4x4_tpu_bf16_mlir',
dtype='bfloat16',
num_tpus=32,
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_8x8_tpu_bf16(self): def benchmark_8x8_tpu_bf16(self):
"""Test Keras model with 8x8 TPU, bf16.""" """Test Keras model with 8x8 TPU, bf16."""
self._setup() self._setup()
...@@ -647,7 +669,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -647,7 +669,7 @@ class Resnet50KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
per_replica_batch_size=64) per_replica_batch_size=64)
def fill_report_object(self, stats): def fill_report_object(self, stats):
super(Resnet50KerasClassifierBenchmarkBase, self).fill_report_object( super(KerasClassifierBenchmarkBase, self).fill_report_object(
stats, stats,
total_batch_size=FLAGS.batch_size, total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps) log_steps=FLAGS.log_steps)
...@@ -1086,7 +1108,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -1086,7 +1108,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
log_steps=FLAGS.log_steps) log_steps=FLAGS.log_steps)
class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase): class Resnet50KerasBenchmarkSynth(KerasClassifierBenchmarkBase):
"""Resnet50 synthetic benchmark tests.""" """Resnet50 synthetic benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs): def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
...@@ -1094,11 +1116,11 @@ class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase): ...@@ -1094,11 +1116,11 @@ class Resnet50KerasBenchmarkSynth(Resnet50KerasClassifierBenchmarkBase):
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
super(Resnet50KerasBenchmarkSynth, self).__init__( super(Resnet50KerasBenchmarkSynth, self).__init__(
output_dir=output_dir, default_flags=def_flags, tpu=tpu, model='resnet', output_dir=output_dir, default_flags=def_flags, tpu=tpu,
dataset_builder='synthetic', train_epochs=1, train_steps=110) dataset_builder='synthetic', train_epochs=1, train_steps=110)
class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase): class Resnet50KerasBenchmarkReal(KerasClassifierBenchmarkBase):
"""Resnet50 real data benchmark tests.""" """Resnet50 real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs): def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
...@@ -1107,11 +1129,25 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase): ...@@ -1107,11 +1129,25 @@ class Resnet50KerasBenchmarkReal(Resnet50KerasClassifierBenchmarkBase):
def_flags['log_steps'] = 10 def_flags['log_steps'] = 10
super(Resnet50KerasBenchmarkReal, self).__init__( super(Resnet50KerasBenchmarkReal, self).__init__(
output_dir=output_dir, default_flags=def_flags, tpu=tpu, model='resnet', output_dir=output_dir, default_flags=def_flags, tpu=tpu,
dataset_builder='records', train_epochs=1, train_steps=110, dataset_builder='records', train_epochs=1, train_steps=110,
data_dir=data_dir) data_dir=data_dir)
class EfficientNetKerasBenchmarkReal(KerasClassifierBenchmarkBase):
"""EfficientNet real data benchmark tests."""
def __init__(self, output_dir=None, root_data_dir=None, tpu=None, **kwargs):
data_dir = os.path.join(root_data_dir, 'imagenet')
def_flags = {}
def_flags['log_steps'] = 10
super(EfficientNetKerasBenchmarkReal, self).__init__(
model='efficientnet', output_dir=output_dir, default_flags=def_flags,
tpu=tpu, dataset_builder='records', train_epochs=1, train_steps=110,
data_dir=data_dir)
class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase): class Resnet50KerasBenchmarkRemoteData(Resnet50KerasBenchmarkBase):
"""Resnet50 real data (stored in remote storage) benchmark tests.""" """Resnet50 real data (stored in remote storage) benchmark tests."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment