"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "e87a2251d1f03eff22ac9136fd5cb21797f4b18d"
Commit 7d2d1caa authored by Zongwei Zhou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 328805239
parent ea70bc22
@@ -74,12 +74,33 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
     super(BertPretrainAccuracyBenchmark, self).__init__(
         output_dir=output_dir, tpu=tpu, **kwargs)
 
+  def _get_distribution_strategy(self, ds_type='mirrored'):
+    """Gets the distribution strategy.
+
+    Args:
+      ds_type: String, the distribution strategy type to be used. Can be
+        'mirrored', 'multi_worker_mirrored', 'tpu' and 'off'.
+
+    Returns:
+      A `tf.distribute.DistributionStrategy` object.
+    """
+    if self.tpu or ds_type == 'tpu':
+      return distribution_utils.get_distribution_strategy(
+          distribution_strategy='tpu', tpu_address=self.tpu)
+    elif ds_type == 'multi_worker_mirrored':
+      # Configures cluster spec for multi-worker distribution strategy.
+      _ = distribution_utils.configure_cluster(FLAGS.worker_hosts,
+                                               FLAGS.task_index)
+    return distribution_utils.get_distribution_strategy(
+        distribution_strategy=ds_type,
+        num_gpus=FLAGS.num_gpus,
+        all_reduce_alg=FLAGS.all_reduce_alg)
+
   @benchmark_wrappers.enable_runtime_flags
   def _run_and_report_benchmark(self, summary_path: str, report_accuracy: bool,
                                 ds_type: str):
     """Runs and reports the benchmark given the provided configuration."""
-    distribution = distribution_utils.get_distribution_strategy(
-        distribution_strategy=ds_type, tpu_address=self.tpu)
+    distribution = self._get_distribution_strategy(ds_type=ds_type)
     logging.info('Flags: %s', flags_core.get_nondefault_flags_as_str())
     start_time_sec = time.time()
     run_pretraining.run_bert_pretrain(
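Note: the new `_get_distribution_strategy` helper defers to `distribution_utils.get_distribution_strategy`, which maps the `ds_type` string onto a `tf.distribute` strategy. A minimal standalone sketch of that mapping is below (illustrative only; the name `get_strategy_sketch` and the `num_gpus` parameter are assumptions, and the real helper additionally handles TPU resolution, cluster configuration, and all-reduce selection):

import tensorflow as tf

def get_strategy_sketch(ds_type='mirrored', num_gpus=1):
  # Illustrative mapping from a ds_type string to a tf.distribute strategy.
  # The benchmark itself uses distribution_utils.get_distribution_strategy.
  if ds_type == 'tpu':
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)
  if ds_type == 'multi_worker_mirrored':
    # Requires TF_CONFIG to describe the cluster (workers and task index).
    return tf.distribute.experimental.MultiWorkerMirroredStrategy()
  if ds_type == 'mirrored':
    devices = ['GPU:%d' % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(devices=devices)
  # ds_type == 'off': fall back to the default (no-op) strategy.
  return tf.distribute.get_strategy()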
@@ -279,6 +300,7 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
     self._setup()
     self._specify_common_flags()
     self._specify_gpu_common_flags()
+    FLAGS.num_gpus = 8
     FLAGS.train_batch_size = 96
     FLAGS.num_steps_per_epoch = 5000
     FLAGS.num_train_epochs = 3
@@ -340,7 +362,7 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
 class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark):
-  """Resnet50 distributed benchmark tests with multiple workers."""
+  """Bert pretrain distributed benchmark tests with multiple workers."""
 
   def __init__(self, output_dir=None, default_flags=None):
     super(BertPretrainMultiWorkerBenchmark, self).__init__(
...
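For the multi-worker variants, `distribution_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)` has to describe the cluster before the strategy is created. A hedged sketch of the mechanism it relies on (the helper name `set_tf_config` and the host addresses are hypothetical; this only illustrates how a worker list and task index are exposed via the TF_CONFIG environment variable):

import json
import os

def set_tf_config(worker_hosts, task_index):
  # Hypothetical illustration: turn a comma-separated host list into the
  # TF_CONFIG layout that MultiWorkerMirroredStrategy reads at construction.
  workers = worker_hosts.split(',')
  os.environ['TF_CONFIG'] = json.dumps({
      'cluster': {'worker': workers},
      'task': {'type': 'worker', 'index': task_index},
  })

# Example: two workers, this process is worker 0 (addresses are placeholders).
set_tf_config('10.0.0.1:2222,10.0.0.2:2222', 0)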