"integration-tests/models/test_flash_qwen2.py" did not exist on "d5b5bc750fd8d00920cb4cb5e5abe121457d717f"
Commit 983837ff authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 365082527
parent a49c7332
@@ -315,6 +315,27 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
        report_accuracy=False,
        ds_type=FLAGS.distribution_strategy)

  @owner_utils.Owner('tf-model-garden')
  def benchmark_perf_8x16_tpu_bf16_seq128_1k_steps(self):
    """Test bert pretraining with 8x16 TPU for 1000 steps."""
    self._setup()
    self._specify_common_flags()
    self._specify_tpu_common_flags()
    FLAGS.train_batch_size = 4096
    FLAGS.warmup_steps = 0
    FLAGS.num_steps_per_epoch = 1000
    FLAGS.num_train_epochs = 1
    FLAGS.steps_per_loop = 500
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_perf_8x16_tpu_bf16_seq128_1k_steps')
    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    # Disable accuracy check.
    self._run_and_report_benchmark(
        summary_path=summary_path,
        report_accuracy=False,
        ds_type=FLAGS.distribution_strategy)

  @owner_utils.Owner('tf-dist-strat')
  def benchmark_accuracy_1x8_gpu_fp16_seq128_15k_steps(self):
    """Test bert pretraining with 8 GPU for 15k steps."""
......
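For context, below is a minimal sketch of how the newly added benchmark method could be driven in isolation. The import path `official.benchmark.bert_benchmark`, the `output_dir` constructor argument, and the `run_perf_benchmark` wrapper are assumptions about the surrounding harness, not part of this commit; in practice the PerfZero/TF benchmark runner discovers and invokes these `benchmark_*` methods itself.

```python
# Hedged sketch (not part of this commit): invoking the new benchmark method
# directly. The import path and constructor argument are assumptions about
# the surrounding PerfZero/TF benchmark harness.
from absl import flags

from official.benchmark import bert_benchmark  # assumed module for this class

FLAGS = flags.FLAGS


def run_perf_benchmark():
  # Absl flags must be parsed before the benchmark mutates them.
  if not FLAGS.is_parsed():
    FLAGS(['bert_benchmark'])
  bench = bert_benchmark.BertPretrainAccuracyBenchmark(
      output_dir='/tmp/bert_pretrain_perf')  # hypothetical output location
  # Runs 1 epoch of 1000 steps at global batch size 4096 (bf16, seq len 128)
  # on an 8x16 TPU slice and reports throughput only (report_accuracy=False).
  bench.benchmark_perf_8x16_tpu_bf16_seq128_1k_steps()


if __name__ == '__main__':
  run_perf_benchmark()
```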