Commit a22d0715 authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 297271153
parent 7b9365dd
@@ -551,6 +551,8 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
     num_gpus = 8
     FLAGS.num_gpus = num_gpus
     FLAGS.dtype = 'fp16'
+    # Enable gradient allreduce in fp16
+    FLAGS.explicit_allreduce = True
     FLAGS.enable_xla = False
     FLAGS.distribution_strategy = 'multi_worker_mirrored'
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
@@ -621,7 +623,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
         min_accuracy=0,
         max_accuracy=1)
 
-  def _benchmark_common(self, num_workers, all_reduce_alg):
+  def _benchmark_common(self, num_workers, all_reduce_alg,
+                        explicit_allreduce=False):
     """Common to all benchmarks in this class."""
     self._setup()
@@ -637,6 +640,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
         num_workers, all_reduce_alg))
     FLAGS.train_batch_size = 4 * num_gpus * num_workers
     FLAGS.all_reduce_alg = all_reduce_alg
+    # Enable gradient allreduce in fp16
+    FLAGS.explicit_allreduce = explicit_allreduce
     self._run_and_report_benchmark()
@@ -650,19 +655,23 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
   def benchmark_8_gpu_2_workers_fp16_ring_tweaked(self):
     """8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
-    self._benchmark_common(num_workers=2, all_reduce_alg='ring')
+    self._benchmark_common(num_workers=2, all_reduce_alg='ring',
+                           explicit_allreduce=True)
 
   def benchmark_8_gpu_2_workers_fp16_nccl_tweaked(self):
     """8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
-    self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
+    self._benchmark_common(num_workers=2, all_reduce_alg='nccl',
+                           explicit_allreduce=True)
 
   def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
     """8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
-    self._benchmark_common(num_workers=8, all_reduce_alg='ring')
+    self._benchmark_common(num_workers=8, all_reduce_alg='ring',
+                           explicit_allreduce=True)
 
   def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
     """8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
-    self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
+    self._benchmark_common(num_workers=8, all_reduce_alg='nccl',
+                           explicit_allreduce=True)
 
 if __name__ == '__main__':
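The substance of this change: the multi-worker BERT SQuAD benchmarks now set `FLAGS.explicit_allreduce`, so gradient allreduce is performed explicitly by the training loop (where it can run in fp16) rather than left to the distribution strategy's default aggregation. As a rough illustration, here is a minimal sketch of fp16 gradient allreduce under a `tf.distribute` replica context; the helper name is hypothetical, and the real wiring in the models repo goes through the `explicit_allreduce` flag rather than this function. The `experimental_aggregate_gradients` argument to `apply_gradients` is available in newer TF 2.x releases.

```python
# Sketch only: illustrates the idea behind FLAGS.explicit_allreduce.
# The helper below is hypothetical, not the models repo's actual API.
import tensorflow as tf

def explicit_fp16_allreduce(grads):
  """All-reduces gradients in float16 to halve cross-worker traffic."""
  replica_ctx = tf.distribute.get_replica_context()
  # Cast down before the collective so each allreduce moves half the bytes.
  fp16_grads = [tf.cast(g, tf.float16) for g in grads]
  reduced = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, fp16_grads)
  # Cast back up before the optimizer applies the update.
  return [tf.cast(g, tf.float32) for g in reduced]

# Inside a train step run via strategy.run(...), usage would look roughly like:
#   grads = tape.gradient(loss, model.trainable_variables)
#   grads = explicit_fp16_allreduce(grads)
#   optimizer.apply_gradients(
#       zip(grads, model.trainable_variables),
#       experimental_aggregate_gradients=False)  # skip the strategy's own reduce
```

Reducing in fp16 halves the bytes each ring/nccl allreduce moves between workers, which is exactly the regime the `_tweaked` multi-worker benchmarks above measure.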