Commit c5ad244e authored by Yanhui Liang, committed by A. Unique TensorFlower

Modify `_get_distribution_strategy` for multi-worker benchmark.

PiperOrigin-RevId: 297016418
parent fb35d6be
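
For context, the `ds_type` values introduced in the diff below ('mirrored', 'multi_worker_mirrored', 'tpu', 'off') roughly select the following core `tf.distribute` strategies. This is an illustrative sketch only, not the actual `distribution_utils.get_distribution_strategy` implementation, which handles more options and validation; the function name here is hypothetical.

```python
import tensorflow as tf


def resolve_strategy_sketch(ds_type='mirrored', tpu_address=None):
  """Illustrative mapping from ds_type strings to tf.distribute strategies."""
  if ds_type == 'tpu':
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.experimental.TPUStrategy(resolver)
  elif ds_type == 'multi_worker_mirrored':
    # Reads the cluster spec from the TF_CONFIG environment variable.
    return tf.distribute.experimental.MultiWorkerMirroredStrategy()
  elif ds_type == 'mirrored':
    return tf.distribute.MirroredStrategy()
  elif ds_type == 'off':
    # No explicit strategy; callers fall back to the default strategy.
    return None
  raise ValueError('Unknown distribution strategy type: %s' % ds_type)
```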
@@ -82,15 +82,27 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
     with tf.io.gfile.GFile(predictions_file, 'r') as reader:
       return json.load(reader)

-  def _get_distribution_strategy(self, use_ds=True):
-    """Gets the distribution strategy."""
-    if self.tpu:
+  def _get_distribution_strategy(self, ds_type='mirrored'):
+    """Gets the distribution strategy.
+
+    Args:
+      ds_type: String, the distribution strategy type to be used. Can be
+        'mirrored', 'multi_worker_mirrored', 'tpu' and 'off'.
+
+    Returns:
+      A `tf.distribute.DistributionStrategy` object.
+    """
+    if self.tpu or ds_type == 'tpu':
       return distribution_utils.get_distribution_strategy(
           distribution_strategy='tpu', tpu_address=self.tpu)
-    else:
-      return distribution_utils.get_distribution_strategy(
-          distribution_strategy='mirrored' if use_ds else 'off',
-          num_gpus=self.num_gpus)
+    elif ds_type == 'multi_worker_mirrored':
+      # Configures cluster spec for multi-worker distribution strategy.
+      _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
+                                               FLAGS.task_index)
+    return distribution_utils.get_distribution_strategy(
+        distribution_strategy=ds_type,
+        num_gpus=self.num_gpus,
+        all_reduce_alg=FLAGS.all_reduce_alg)

   def _init_gpu_and_data_threads(self):
     """Set env variables before any TF calls."""
@@ -102,12 +114,12 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
         datasets_num_private_threads=FLAGS.datasets_num_private_threads)

   @flagsaver.flagsaver
-  def _train_squad(self, use_ds=True, run_eagerly=False):
-    """Runs BERT SQuAD training."""
+  def _train_squad(self, run_eagerly=False, ds_type='mirrored'):
+    """Runs BERT SQuAD training. Uses mirrored strategy by default."""
     assert tf.version.VERSION.startswith('2.')

     self._init_gpu_and_data_threads()
     input_meta_data = self._read_input_meta_data_from_file()
-    strategy = self._get_distribution_strategy(use_ds)
+    strategy = self._get_distribution_strategy(ds_type)

     run_squad.train_squad(
         strategy=strategy,
@@ -116,12 +128,12 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
         custom_callbacks=[self.timer_callback])

   @flagsaver.flagsaver
-  def _evaluate_squad(self, use_ds=True):
-    """Runs BERT SQuAD evaluation."""
+  def _evaluate_squad(self, ds_type='mirrored'):
+    """Runs BERT SQuAD evaluation. Uses mirrored strategy by default."""
     assert tf.version.VERSION.startswith('2.')

     self._init_gpu_and_data_threads()
     input_meta_data = self._read_input_meta_data_from_file()
-    strategy = self._get_distribution_strategy(use_ds)
+    strategy = self._get_distribution_strategy(ds_type)

     run_squad.predict_squad(strategy=strategy, input_meta_data=input_meta_data)
@@ -157,15 +169,15 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
   @benchmark_wrappers.enable_runtime_flags
   def _run_and_report_benchmark(self,
-                                use_ds=True,
-                                run_eagerly=False):
+                                run_eagerly=False,
+                                ds_type='mirrored'):
     """Runs the benchmark and reports various metrics."""
     if FLAGS.train_batch_size <= 4:
       FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
     else:
       FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
     start_time_sec = time.time()
-    self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
+    self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
     wall_time_sec = time.time() - start_time_sec

     summary = self._read_training_summary_from_file()
@@ -217,7 +229,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_no_dist_strat_squad')
     FLAGS.train_batch_size = 4

-    self._run_and_report_benchmark(use_ds=False)
+    self._run_and_report_benchmark(ds_type='off')

   def benchmark_1_gpu_eager_no_dist_strat(self):
     """Tests BERT SQuAD model performance with 1 GPU with eager execution."""
@@ -228,7 +240,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
         'benchmark_1_gpu_eager_no_dist_strat_squad')
     FLAGS.train_batch_size = 4

-    self._run_and_report_benchmark(use_ds=False, run_eagerly=True)
+    self._run_and_report_benchmark(ds_type='off', run_eagerly=True)

   def benchmark_2_gpu(self):
     """Tests BERT SQuAD model performance with 2 GPUs."""
@@ -420,12 +432,12 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
   @benchmark_wrappers.enable_runtime_flags
   def _run_and_report_benchmark(self,
-                                use_ds=True,
-                                run_eagerly=False):
+                                run_eagerly=False,
+                                ds_type='mirrored'):
     """Runs the benchmark and reports various metrics."""
     start_time_sec = time.time()
-    self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
-    self._evaluate_squad()
+    self._train_squad(run_eagerly=run_eagerly, ds_type=ds_type)
+    self._evaluate_squad(ds_type=ds_type)
     wall_time_sec = time.time() - start_time_sec

     summary = self._read_training_summary_from_file()
@@ -445,7 +457,7 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad_eager')
     FLAGS.train_batch_size = 4

-    self._run_and_report_benchmark(use_ds=False, run_eagerly=True)
+    self._run_and_report_benchmark(ds_type='off', run_eagerly=True)

   def benchmark_8_gpu(self):
     """Tests BERT SQuAD model accuracy with 8 GPUs."""
@@ -518,8 +530,9 @@ class BertSquadMultiWorkerAccuracy(BertSquadBenchmarkBase):
                                 run_eagerly=False):
     """Runs the benchmark and reports various metrics."""
     start_time_sec = time.time()
-    self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
-    self._evaluate_squad()
+    self._train_squad(run_eagerly=run_eagerly,
+                      ds_type='multi_worker_mirrored')
+    self._evaluate_squad(ds_type='multi_worker_mirrored')
     wall_time_sec = time.time() - start_time_sec

     summary = self._read_training_summary_from_file()
@@ -595,7 +608,8 @@ class BertSquadMultiWorkerBenchmark(BertSquadBenchmarkBase):
     else:
       FLAGS.input_meta_data_path = SQUAD_FULL_INPUT_META_DATA_PATH
     start_time_sec = time.time()
-    self._train_squad(use_ds=use_ds, run_eagerly=run_eagerly)
+    self._train_squad(run_eagerly=run_eagerly,
+                      ds_type='multi_worker_mirrored')
     wall_time_sec = time.time() - start_time_sec

     summary = self._read_training_summary_from_file()
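
Putting the multi-worker pieces together: each worker process ends up with a TF_CONFIG derived from the `worker_hosts`/`task_index` flags and a strategy whose collective implementation corresponds to `all_reduce_alg`. A hedged sketch of that end state (host names, ports, and the NCCL choice are placeholders; one such process must be started per entry in the worker list, or strategy construction will wait for missing peers):

```python
import json
import os

import tensorflow as tf

# Placeholder cluster spec; in the benchmark this comes from the
# worker_hosts/task_index flags via distribution_utils.configure_cluster.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ['host1:2222', 'host2:2222']},
    'task': {'type': 'worker', 'index': 0},
})

# all_reduce_alg selects the collective implementation, e.g. ring or NCCL.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CollectiveCommunication.NCCL)

print('Replicas in sync across the cluster:', strategy.num_replicas_in_sync)
```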