Commit a22d0715 authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 297271153
parent 7b9365dd
@@ -551,6 +551,8 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
     num_gpus = 8
     FLAGS.num_gpus = num_gpus
     FLAGS.dtype = 'fp16'
+    # Enable gradient allreduce in fp16
+    FLAGS.explicit_allreduce = True
     FLAGS.enable_xla = False
     FLAGS.distribution_strategy = 'multi_worker_mirrored'
     FLAGS.tf_gpu_thread_mode = 'gpu_private'
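
For context: "explicit allreduce" here means the training loop cross-replica-sums the gradients itself, in fp16, instead of letting the optimizer aggregate them implicitly in fp32; reducing in half precision roughly halves allreduce traffic. Below is a minimal sketch of that pattern. It illustrates the technique only and is not the model-garden code path behind FLAGS.explicit_allreduce; the names train_step, model, optimizer, and loss_fn are assumptions. It relies on tf.distribute.get_replica_context() (so it must run inside strategy.run) and on the experimental_aggregate_gradients keyword of Keras optimizers (TF >= 2.2).

import tensorflow as tf

def train_step(inputs, labels, model, optimizer, loss_fn):
  """One replica-local step with an explicit fp16 gradient allreduce."""
  with tf.GradientTape() as tape:
    loss = loss_fn(labels, model(inputs, training=True))
  grads = tape.gradient(loss, model.trainable_variables)
  # Cast gradients to fp16 so the cross-replica sum moves half the bytes.
  fp16_grads = [tf.cast(g, tf.float16) for g in grads]
  reduced = tf.distribute.get_replica_context().all_reduce(
      tf.distribute.ReduceOp.SUM, fp16_grads)
  # Cast back to fp32 for a numerically safer variable update.
  fp32_grads = [tf.cast(g, tf.float32) for g in reduced]
  # Gradients are already aggregated, so tell the optimizer not to redo it.
  optimizer.apply_gradients(
      zip(fp32_grads, model.trainable_variables),
      experimental_aggregate_gradients=False)
  return loss
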
@@ -621,7 +623,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
         min_accuracy=0,
         max_accuracy=1)
 
-  def _benchmark_common(self, num_workers, all_reduce_alg):
+  def _benchmark_common(self, num_workers, all_reduce_alg,
+                        explicit_allreduce=False):
     """Common to all benchmarks in this class."""
     self._setup()
@@ -637,6 +640,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
                                      num_workers, all_reduce_alg))
     FLAGS.train_batch_size = 4 * num_gpus * num_workers
     FLAGS.all_reduce_alg = all_reduce_alg
+    # Enable gradient allreduce in fp16
+    FLAGS.explicit_allreduce = explicit_allreduce
 
     self._run_and_report_benchmark()
@@ -650,19 +655,23 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
 
   def benchmark_8_gpu_2_workers_fp16_ring_tweaked(self):
     """8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
-    self._benchmark_common(num_workers=2, all_reduce_alg='ring')
+    self._benchmark_common(num_workers=2, all_reduce_alg='ring',
+                           explicit_allreduce=True)
 
   def benchmark_8_gpu_2_workers_fp16_nccl_tweaked(self):
     """8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
-    self._benchmark_common(num_workers=2, all_reduce_alg='nccl')
+    self._benchmark_common(num_workers=2, all_reduce_alg='nccl',
+                           explicit_allreduce=True)
 
   def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
     """8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
-    self._benchmark_common(num_workers=8, all_reduce_alg='ring')
+    self._benchmark_common(num_workers=8, all_reduce_alg='ring',
+                           explicit_allreduce=True)
 
   def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
     """8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
-    self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
+    self._benchmark_common(num_workers=8, all_reduce_alg='nccl',
+                           explicit_allreduce=True)
 
 
 if __name__ == '__main__':
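
The 'ring' and 'nccl' values assigned to FLAGS.all_reduce_alg select the collective backend used under the 'multi_worker_mirrored' distribution strategy. A rough sketch of the equivalent direct strategy construction follows, using the TF 2.4+ public API; the model garden's own flag-to-strategy plumbing is not shown, and make_strategy is a hypothetical helper name.

import tensorflow as tf

def make_strategy(all_reduce_alg='nccl'):
  """Builds MultiWorkerMirroredStrategy with the chosen collective backend."""
  impl = {
      'ring': tf.distribute.experimental.CommunicationImplementation.RING,
      'nccl': tf.distribute.experimental.CommunicationImplementation.NCCL,
  }[all_reduce_alg]
  options = tf.distribute.experimental.CommunicationOptions(
      implementation=impl)
  return tf.distribute.MultiWorkerMirroredStrategy(
      communication_options=options)

NCCL is generally preferred on NVIDIA GPUs, while ring is a portable fallback; the benchmarks above exercise both so the two code paths can be compared.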