Commit 2741cc5f authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

BERT enable gpu thread mode as "gpu_private" for 8-GPU runs.

PiperOrigin-RevId: 292369407
parent 2c48c0dd
......@@ -32,6 +32,7 @@ from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.benchmark import squad_evaluate_v1_1
from official.nlp.bert import run_squad
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.testing import benchmark_wrappers
......@@ -90,10 +91,22 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
distribution_strategy='mirrored' if use_ds else 'off',
num_gpus=self.num_gpus)
def _init_gpu_and_data_threads(self):
"""Set env variables before any TF calls."""
if FLAGS.tf_gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=FLAGS.per_gpu_thread_count,
gpu_thread_mode=FLAGS.tf_gpu_thread_mode,
num_gpus=self.num_gpus,
datasets_num_private_threads=FLAGS.datasets_num_private_threads)
@flagsaver.flagsaver
def _train_squad(self, use_ds=True, run_eagerly=False):
"""Runs BERT SQuAD training."""
assert tf.version.VERSION.startswith('2.')
self._init_gpu_and_data_threads()
input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(use_ds)
......@@ -107,6 +120,7 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
def _evaluate_squad(self, use_ds=True):
"""Runs BERT SQuAD evaluation."""
assert tf.version.VERSION.startswith('2.')
self._init_gpu_and_data_threads()
input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(use_ds)
......@@ -231,6 +245,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
FLAGS.train_batch_size = 32
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -292,6 +307,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
FLAGS.train_batch_size = 32
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -328,6 +344,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
FLAGS.train_batch_size = 32
FLAGS.dtype = 'fp16'
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -400,6 +417,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
FLAGS.train_batch_size = 24
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -412,6 +430,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
FLAGS.train_batch_size = 32
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 'dynamic'
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......@@ -423,6 +442,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad_xla')
FLAGS.train_batch_size = 32
FLAGS.enable_xla = True
FLAGS.tf_gpu_thread_mode = 'gpu_private'
self._run_and_report_benchmark()
......
......@@ -83,6 +83,8 @@ def define_common_bert_flags():
loss_scale=True,
all_reduce_alg=True,
num_packs=False,
tf_gpu_thread_mode=True,
datasets_num_private_threads=True,
enable_xla=True,
fp16_implementation=True,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment