Commit 063e258b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 360809215
parent 06d41881
...@@ -635,6 +635,15 @@ class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -635,6 +635,15 @@ class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
distribution_strategy='tpu', distribution_strategy='tpu',
per_replica_batch_size=128) per_replica_batch_size=128)
def benchmark_2x2_tpu(self):
"""Test Keras model with 2x2 TPU."""
self._setup()
self._run_and_report_benchmark(
experiment_name='benchmark_2x2_tpu',
num_tpus=8,
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_4x4_tpu_bf16(self): def benchmark_4x4_tpu_bf16(self):
"""Test Keras model with 4x4 TPU, bf16.""" """Test Keras model with 4x4 TPU, bf16."""
self._setup() self._setup()
...@@ -645,6 +654,15 @@ class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -645,6 +654,15 @@ class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
distribution_strategy='tpu', distribution_strategy='tpu',
per_replica_batch_size=128) per_replica_batch_size=128)
def benchmark_4x4_tpu(self):
"""Test Keras model with 4x4 TPU."""
self._setup()
self._run_and_report_benchmark(
experiment_name='benchmark_4x4_tpu',
num_tpus=32,
distribution_strategy='tpu',
per_replica_batch_size=128)
def benchmark_2x2_tpu_bf16_mlir(self): def benchmark_2x2_tpu_bf16_mlir(self):
"""Test Keras model with 2x2 TPU, bf16.""" """Test Keras model with 2x2 TPU, bf16."""
self._setup() self._setup()
...@@ -677,6 +695,15 @@ class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -677,6 +695,15 @@ class KerasClassifierBenchmarkBase(keras_benchmark.KerasBenchmark):
distribution_strategy='tpu', distribution_strategy='tpu',
per_replica_batch_size=64) per_replica_batch_size=64)
def benchmark_8x8_tpu(self):
"""Test Keras model with 8x8 TPU."""
self._setup()
self._run_and_report_benchmark(
experiment_name='benchmark_8x8_tpu',
num_tpus=128,
distribution_strategy='tpu',
per_replica_batch_size=64)
def fill_report_object(self, stats): def fill_report_object(self, stats):
super(KerasClassifierBenchmarkBase, self).fill_report_object( super(KerasClassifierBenchmarkBase, self).fill_report_object(
stats, stats,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment