Commit d4ef56b1 authored by Yanhui Liang, committed by A. Unique TensorFlower

Add multi-worker dist-strat benchmark/accuracy for BERT SQUAD model.

PiperOrigin-RevId: 291851815
parent a7d1b2b3
@@ -436,5 +436,153 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
    self._run_and_report_benchmark()
class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
  """BERT SQuAD distributed accuracy tests with multiple workers."""

  def __init__(self, output_dir=None, tpu=None, **kwargs):
    super(BertSquadMultiWorkerAccuracy, self).__init__(
        output_dir=output_dir, tpu=tpu)

  def _setup(self):
    """Sets up the benchmark and SQuAD flags."""
    super(BertSquadMultiWorkerAccuracy, self)._setup()
    FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
    FLAGS.predict_file = SQUAD_PREDICT_FILE
    FLAGS.vocab_file = SQUAD_VOCAB_FILE
    FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
    FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
    FLAGS.init_checkpoint = PRETRAINED_CHECKPOINT_PATH
    FLAGS.num_train_epochs = 2
    FLAGS.steps_per_loop = 1

  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                use_ds=True,
                                run_eagerly=False):
    """Runs the benchmark and reports various metrics."""
    start_time_sec = time.time()
    self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
    self._evaluate_squad()
    wall_time_sec = time.time() - start_time_sec

    summary = self._read_training_summary_from_file()
    summary['eval_metrics'] = self.eval_metrics

    super(BertSquadMultiWorkerAccuracy, self)._report_benchmark(
        stats=summary,
        wall_time_sec=wall_time_sec,
        min_accuracy=0.900,
        max_accuracy=0.920)
  def _benchmark_common(self, num_workers, all_reduce_alg):
    """Common to all benchmarks in this class."""
    self._setup()

    num_gpus = 8
    FLAGS.num_gpus = num_gpus
    FLAGS.dtype = 'fp16'
    FLAGS.enable_eager = True
    FLAGS.enable_xla = False
    FLAGS.distribution_strategy = 'multi_worker_mirrored'
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    FLAGS.datasets_num_private_threads = 32
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
            num_workers, all_reduce_alg))
    FLAGS.train_batch_size = 4 * num_gpus * num_workers
    FLAGS.all_reduce_alg = all_reduce_alg

    self._run_and_report_benchmark()
  def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
    """8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
    self._benchmark_common(num_workers=8, all_reduce_alg='ring')

  def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
    """8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
    self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
  """BERT SQuAD distributed benchmark tests with multiple workers."""

  def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
    super(BertSquadMultiWorkerBenchmark, self).__init__(
        output_dir=output_dir, tpu=tpu)

  def _setup(self):
    """Sets up the benchmark and SQuAD flags."""
    super(BertSquadMultiWorkerBenchmark, self)._setup()
    FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
    FLAGS.predict_file = SQUAD_PREDICT_FILE
    FLAGS.vocab_file = SQUAD_VOCAB_FILE
    FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
    FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
    FLAGS.num_train_epochs = 1
    FLAGS.steps_per_loop = 1

  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                use_ds=True,
                                run_eagerly=False):
    """Runs the benchmark and reports various metrics."""
    start_time_sec = time.time()
    self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
    wall_time_sec = time.time() - start_time_sec

    summary = self._read_training_summary_from_file()
    summary['start_time_sec'] = start_time_sec

    super(BertSquadMultiWorkerBenchmark, self)._report_benchmark(
        stats=summary,
        wall_time_sec=wall_time_sec,
        min_accuracy=0,
        max_accuracy=1)

  def _benchmark_common(self, num_workers, all_reduce_alg):
    """Common to all benchmarks in this class."""
    self._setup()

    num_gpus = 8
    FLAGS.num_gpus = num_gpus
    FLAGS.dtype = 'fp16'
    FLAGS.enable_eager = True
    FLAGS.enable_xla = False
    FLAGS.distribution_strategy = 'multi_worker_mirrored'
    FLAGS.tf_gpu_thread_mode = 'gpu_private'
    FLAGS.datasets_num_private_threads = 32
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_8_gpu_{}_worker_fp16_{}_tweaked'.format(
            num_workers, all_reduce_alg))
    FLAGS.train_batch_size = 4 * num_gpus * num_workers
    FLAGS.all_reduce_alg = all_reduce_alg

    self._run_and_report_benchmark()

  def benchmark_8_gpu_1_worker_fp16_ring_tweaked(self):
    """8 GPUs per worker, 1 worker, fp16, ring all-reduce."""
    self._benchmark_common(num_workers=1, all_reduce_alg='ring')

  def benchmark_8_gpu_1_worker_fp16_nccl_tweaked(self):
    """8 GPUs per worker, 1 worker, fp16, nccl all-reduce."""
    self._benchmark_common(num_workers=1, all_reduce_alg='nccl')

  def benchmark_8_gpu_2_workers_fp16_ring_tweaked(self):
    """8 GPUs per worker, 2 workers, fp16, ring all-reduce."""
    self._benchmark_common(num_workers=2, all_reduce_alg='ring')

  def benchmark_8_gpu_2_workers_fp16_nccl_tweaked(self):
    """8 GPUs per worker, 2 workers, fp16, nccl all-reduce."""
    self._benchmark_common(num_workers=2, all_reduce_alg='nccl')

  def benchmark_8_gpu_8_workers_fp16_ring_tweaked(self):
    """8 GPUs per worker, 8 workers, fp16, ring all-reduce."""
    self._benchmark_common(num_workers=8, all_reduce_alg='ring')

  def benchmark_8_gpu_8_workers_fp16_nccl_tweaked(self):
    """8 GPUs per worker, 8 workers, fp16, nccl all-reduce."""
    self._benchmark_common(num_workers=8, all_reduce_alg='nccl')
if __name__ == '__main__':
  tf.test.main()
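For context on the benchmark methods above: the global batch size is 4 examples per GPU × num_gpus × num_workers, so the 8-worker runs train with 4 × 8 × 8 = 256 examples per step. The 'ring' and 'nccl' values passed as all_reduce_alg correspond to the collective-communication options of MultiWorkerMirroredStrategy. The sketch below is illustrative only; the real mapping happens inside distribution_utils.get_distribution_strategy, the helper name here is hypothetical, and it assumes the TF 2.1-era experimental API.

import tensorflow as tf

def make_multi_worker_strategy(all_reduce_alg='nccl'):
  """Hypothetical sketch: map an all_reduce_alg string to a strategy."""
  communication = {
      'ring': tf.distribute.experimental.CollectiveCommunication.RING,
      'nccl': tf.distribute.experimental.CollectiveCommunication.NCCL,
  }[all_reduce_alg]
  return tf.distribute.experimental.MultiWorkerMirroredStrategy(
      communication=communication)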
@@ -35,6 +35,7 @@ def define_common_bert_flags():
      export_dir=False,
      distribution_strategy=True,
      run_eagerly=True)
  flags_core.define_distribution()
  flags.DEFINE_string('bert_config_file', None,
                      'Bert configuration file to define core bert layers.')
  flags.DEFINE_string(
......
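The new flags_core.define_distribution() call registers the multi-worker flags (worker_hosts and task_index) that run_squad.py reads in the hunk below. The actual definitions live in the official flags module; the absl.flags snippet here is only a hedged approximation of what such definitions look like.

from absl import flags

# Approximation only; the real flags come from flags_core.define_distribution().
flags.DEFINE_string(
    'worker_hosts', None,
    'Comma-separated list of worker ip:port pairs for multi-worker runs.')
flags.DEFINE_integer(
    'task_index', -1,
    'Index of this worker within the cluster (used with worker_hosts).')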
@@ -406,9 +406,14 @@ def main(_):
    export_squad(FLAGS.model_export_path, input_meta_data)
    return

  # Configures cluster spec for multi-worker distribution strategy.
  if FLAGS.num_gpus > 0:
    _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
                                             FLAGS.task_index)
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.distribution_strategy,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg,
      tpu_address=FLAGS.tpu)
  if FLAGS.mode in ('train', 'train_and_predict'):
    train_squad(strategy, input_meta_data)
......
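configure_cluster supplies MultiWorkerMirroredStrategy with the cluster spec it expects from the environment. A minimal sketch of that kind of setup, assuming the standard TF_CONFIG mechanism (the function below is an approximation for illustration, not the actual distribution_utils.configure_cluster implementation):

import json
import os

def configure_cluster_sketch(worker_hosts, task_index):
  """Hedged sketch: export a TF_CONFIG cluster spec for multi-worker training."""
  workers = worker_hosts.split(',')  # e.g. 'host1:2222,host2:2222'
  os.environ['TF_CONFIG'] = json.dumps({
      'cluster': {'worker': workers},
      'task': {'type': 'worker', 'index': task_index},
  })
  return len(workers)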