Commit 149c6fa3 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 309805531
parent 31b0e518
......@@ -131,6 +131,17 @@ class Unet3DAccuracyBenchmark(keras_benchmark.KerasBenchmark):
dtype='bfloat16',
distribution_strategy='tpu')
def benchmark_4x4_tpu_bf16_mlir(self):
"""Test Keras model with 4x4 TPU, fp16 and MLIR enabled."""
experiment_name = 'benchmark_4x4_tpu_fp16_mlir'
tf.config.experimental.enable_mlir_bridge()
self._setup()
self._set_benchmark_parameters(experiment_name)
self._run_and_report_benchmark(
experiment_name=experiment_name,
dtype='bfloat16',
distribution_strategy='tpu')
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment