"server/vscode:/vscode.git/clone" did not exist on "d789de329a087301d651ee943e0d76e0dbf5ced5"
Commit 8a670c65 authored by Pankaj Kanwar, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 334221902
parent 4551e0fb
@@ -294,6 +294,31 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
     self._run_and_report_benchmark()
 
+  def benchmark_8_gpu_xla_tf32(self):
+    """Tests BERT SQuAD model performance with 8 GPUs with XLA using TF32."""
+    self._setup()
+    self.num_gpus = 8
+    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_xla_tf32')
+    FLAGS.train_batch_size = 32
+    FLAGS.enable_xla = True
+    FLAGS.loss_scale = 'dynamic'
+    self._run_and_report_benchmark()
+
+  def benchmark_8_gpu_xla_fp32_no_tf32(self):
+    """Tests BERT SQuAD model performance with 8 GPUs with XLA using FP32."""
+    self._setup()
+    tf.config.experimental.enable_tensor_float_32_execution(False)
+    self.num_gpus = 8
+    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_xla_fp32_no_tf32')
+    FLAGS.train_batch_size = 32
+    FLAGS.enable_xla = True
+    FLAGS.loss_scale = 'dynamic'
+    self._run_and_report_benchmark()
+
   def benchmark_1_gpu_amp(self):
     """Tests BERT SQuAD model performance with 1 GPU with automatic mixed precision."""
@@ -547,7 +547,7 @@ class TransformerKerasBenchmark(TransformerBenchmark):
                                    log_steps=FLAGS.log_steps)
 
   def benchmark_8_gpu(self):
-    """Benchmark 8 gpu."""
+    """Benchmark 8 gpu. This defaults to using TF32."""
     self._setup()
     FLAGS.num_gpus = 8
     FLAGS.batch_size = self.batch_per_gpu * 8
@@ -566,7 +566,7 @@ class TransformerKerasBenchmark(TransformerBenchmark):
                                    log_steps=FLAGS.log_steps)
 
   def benchmark_xla_8_gpu(self):
-    """Benchmark 8 gpu w/xla."""
+    """Benchmark 8 gpu w/xla. This defaults to using TF32."""
     self._setup()
     FLAGS.num_gpus = 8
     FLAGS.enable_xla = True
@@ -636,6 +636,19 @@ class TransformerKerasBenchmark(TransformerBenchmark):
     self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
                                    log_steps=FLAGS.log_steps)
 
+  def benchmark_xla_8_gpu_static_batch_fp32_no_tf32(self):
+    """Benchmark 8 gpu with static batch w/xla using FP32 with TF32 disabled."""
+    self._setup()
+    FLAGS.num_gpus = 8
+    FLAGS.enable_xla = True
+    FLAGS.batch_size = self.batch_per_gpu * 8
+    FLAGS.model_dir = self._get_model_dir(
+        'benchmark_xla_8_gpu_static_batch_fp32_no_tf32')
+    FLAGS.static_batch = True
+    FLAGS.max_length = 64
+    self._run_and_report_benchmark(total_batch_size=FLAGS.batch_size,
+                                   log_steps=FLAGS.log_steps)
+
 
 class TransformerBaseKerasBenchmarkReal(TransformerKerasBenchmark):
   """Transformer based version real data benchmark tests."""