Commit 97e04472 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Updating Bert Squad batch size 4. XLA batch size still at 3, as it OOMs at 4.

PiperOrigin-RevId: 290339640
parent abd510a6
...@@ -167,7 +167,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -167,7 +167,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
self._setup() self._setup()
self.num_gpus = 1 self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad') FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad')
FLAGS.train_batch_size = 3 FLAGS.train_batch_size = 4
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -189,7 +189,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -189,7 +189,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
self._setup() self._setup()
self.num_gpus = 1 self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat_squad') FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat_squad')
FLAGS.train_batch_size = 3 FLAGS.train_batch_size = 4
self._run_and_report_benchmark(use_ds=False) self._run_and_report_benchmark(use_ds=False)
...@@ -200,7 +200,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -200,7 +200,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
self.num_gpus = 1 self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir( FLAGS.model_dir = self._get_model_dir(
'benchmark_1_gpu_eager_no_dist_strat_squad') 'benchmark_1_gpu_eager_no_dist_strat_squad')
FLAGS.train_batch_size = 3 FLAGS.train_batch_size = 4
self._run_and_report_benchmark(use_ds=False, run_eagerly=True) self._run_and_report_benchmark(use_ds=False, run_eagerly=True)
...@@ -210,7 +210,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -210,7 +210,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
self._setup() self._setup()
self.num_gpus = 2 self.num_gpus = 2
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_squad') FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_squad')
FLAGS.train_batch_size = 6 FLAGS.train_batch_size = 8
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -220,7 +220,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -220,7 +220,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
self._setup() self._setup()
self.num_gpus = 4 self.num_gpus = 4
FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_squad') FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_squad')
FLAGS.train_batch_size = 12 FLAGS.train_batch_size = 16
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -230,7 +230,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -230,7 +230,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
self._setup() self._setup()
self.num_gpus = 8 self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
FLAGS.train_batch_size = 24 FLAGS.train_batch_size = 32
self._run_and_report_benchmark() self._run_and_report_benchmark()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment