"tools/vscode:/vscode.git/clone" did not exist on "9734dcdec4249086b8278cda305f3d4f9f3b9b12"
Commit 7d2d1caa authored by Zongwei Zhou, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 328805239
parent ea70bc22
......@@ -74,12 +74,33 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
super(BertPretrainAccuracyBenchmark, self).__init__(
output_dir=output_dir, tpu=tpu, **kwargs)
def _get_distribution_strategy(self, ds_type='mirrored'):
  """Builds the distribution strategy for a benchmark run.

  A TPU strategy is requested whenever `self.tpu` is set or `ds_type` is
  'tpu'. For 'multi_worker_mirrored' the cluster is configured from
  `FLAGS.worker_hosts` and `FLAGS.task_index` before the strategy is built.

  Args:
    ds_type: String, the distribution strategy type to be used. Can be
      'mirrored', 'multi_worker_mirrored', 'tpu' and 'off'.

  Returns:
    A `tf.distribute.DistributionStrategy` object, as produced by
    `distribution_utils.get_distribution_strategy`.
  """
  # TPU takes precedence over whatever strategy type was requested.
  if self.tpu or ds_type == 'tpu':
    return distribution_utils.get_distribution_strategy(
        distribution_strategy='tpu', tpu_address=self.tpu)

  if ds_type == 'multi_worker_mirrored':
    # Configures cluster spec for multi-worker distribution strategy; the
    # return value is not needed here.
    distribution_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)

  strategy_kwargs = dict(
      distribution_strategy=ds_type,
      num_gpus=FLAGS.num_gpus,
      all_reduce_alg=FLAGS.all_reduce_alg)
  return distribution_utils.get_distribution_strategy(**strategy_kwargs)
@benchmark_wrappers.enable_runtime_flags
def _run_and_report_benchmark(self, summary_path: str, report_accuracy: bool,
ds_type: str):
"""Runs and reports the benchmark given the provided configuration."""
distribution = distribution_utils.get_distribution_strategy(
distribution_strategy=ds_type, tpu_address=self.tpu)
distribution = self._get_distribution_strategy(ds_type=ds_type)
logging.info('Flags: %s', flags_core.get_nondefault_flags_as_str())
start_time_sec = time.time()
run_pretraining.run_bert_pretrain(
......@@ -279,6 +300,7 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
self._setup()
self._specify_common_flags()
self._specify_gpu_common_flags()
FLAGS.num_gpus = 8
FLAGS.train_batch_size = 96
FLAGS.num_steps_per_epoch = 5000
FLAGS.num_train_epochs = 3
......@@ -340,7 +362,7 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark):
"""Resnet50 distributed benchmark tests with multiple workers."""
"""Bert pretrain distributed benchmark tests with multiple workers."""
def __init__(self, output_dir=None, default_flags=None):
super(BertPretrainMultiWorkerBenchmark, self).__init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment