Commit 10b38209 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 309479930
parent 4eda0048
...@@ -25,6 +25,7 @@ from absl import logging ...@@ -25,6 +25,7 @@ from absl import logging
from absl.testing import flagsaver from absl.testing import flagsaver
import tensorflow as tf import tensorflow as tf
from official.benchmark import benchmark_wrappers from official.benchmark import benchmark_wrappers
from official.benchmark import owner_utils
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
from official.recommendation import ncf_common from official.recommendation import ncf_common
from official.recommendation import ncf_keras_main from official.recommendation import ncf_keras_main
...@@ -433,6 +434,17 @@ class NCFKerasBenchmarkReal(NCFKerasBenchmarkBase): ...@@ -433,6 +434,17 @@ class NCFKerasBenchmarkReal(NCFKerasBenchmarkBase):
FLAGS.train_epochs = 1 FLAGS.train_epochs = 1
self._run_and_report_benchmark() self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler')
def benchmark_2x2_tpu_mlir(self):
"""2x2 TPU using CTL with distribution strategy using the MLIR bridge."""
self._setup()
FLAGS.distribution_strategy = 'tpu'
FLAGS.keras_use_ctl = True
FLAGS.num_gpus = 0
FLAGS.train_epochs = 1
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark()
class NCFKerasSynth(NCFKerasBenchmarkBase): class NCFKerasSynth(NCFKerasBenchmarkBase):
"""Benchmark NCF model using synthetic data.""" """Benchmark NCF model using synthetic data."""
......
...@@ -22,6 +22,7 @@ import time ...@@ -22,6 +22,7 @@ import time
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from official.benchmark import owner_utils
from official.vision.image_classification.resnet import common from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import resnet_ctl_imagenet_main from official.vision.image_classification.resnet import resnet_ctl_imagenet_main
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
...@@ -395,6 +396,16 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark): ...@@ -395,6 +396,16 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.dtype = 'bf16' FLAGS.dtype = 'bf16'
self._run_and_report_benchmark() self._run_and_report_benchmark()
@owner_utils.Owner('tf-graph-compiler')
def benchmark_4x4_tpu_bf16_mlir(self):
"""Run resnet model on 4x4 with the MLIR Bridge enabled."""
self._setup()
self._set_df_common()
FLAGS.batch_size = 4096
FLAGS.dtype = 'bf16'
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark()
def benchmark_8x16_tpu_bf16(self): def benchmark_8x16_tpu_bf16(self):
self._setup() self._setup()
self._set_df_common() self._set_df_common()
......
...@@ -23,6 +23,7 @@ import time ...@@ -23,6 +23,7 @@ import time
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from official.benchmark import benchmark_wrappers from official.benchmark import benchmark_wrappers
from official.benchmark import owner_utils
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
from official.nlp.transformer import misc from official.nlp.transformer import misc
from official.nlp.transformer import transformer_main as transformer_main from official.nlp.transformer import transformer_main as transformer_main
...@@ -728,6 +729,29 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark): ...@@ -728,6 +729,29 @@ class TransformerBigKerasBenchmarkReal(TransformerKerasBenchmark):
total_batch_size=FLAGS.batch_size, total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps) log_steps=FLAGS.log_steps)
@owner_utils.Owner('tf-graph-compiler')
def benchmark_4x4_tpu_mlir(self):
"""Run transformer_big model on 4x4 with the MLIR Bridge enabled."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_4x4_tpu')
FLAGS.train_steps = 300
FLAGS.log_steps = 150
FLAGS.steps_between_evals = 150
FLAGS.distribution_strategy = 'tpu'
FLAGS.static_batch = True
FLAGS.use_ctl = True
FLAGS.batch_size = 24576
FLAGS.max_length = 64
FLAGS.decode_batch_size = 32
FLAGS.decode_max_length = 97
FLAGS.padded_decode = True
FLAGS.enable_checkpointing = False
tf.config.experimental.enable_mlir_bridge()
self._run_and_report_benchmark(
total_batch_size=FLAGS.batch_size,
log_steps=FLAGS.log_steps)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment