Commit 50fe1963 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 330007126
parent 16e2edce
......@@ -107,10 +107,13 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
strategy=distribution, custom_callbacks=self.timer_callback)
wall_time_sec = time.time() - start_time_sec
with tf.io.gfile.GFile(summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
self._report_benchmark(summary, start_time_sec, wall_time_sec,
report_accuracy, ds_type)
# For GPU multi-worker, the summary text file is only generated on chief
# (metrics aggregated), so only chief has to report the result.
if tf.io.gfile.exists(summary_path):
with tf.io.gfile.GFile(summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
self._report_benchmark(summary, start_time_sec, wall_time_sec,
report_accuracy, ds_type)
def _report_benchmark(self, summary, start_time_sec, wall_time_sec,
report_accuracy, ds_type):
......@@ -429,7 +432,7 @@ class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark):
"""Test bert pretraining with 1x8 GPU for 200 steps."""
self._setup()
self._specify_common_flags()
self._specify_gpu_common_flags()
self._specify_gpu_mwms_flags()
FLAGS.num_steps_per_epoch = 200
FLAGS.num_train_epochs = 1
FLAGS.train_batch_size = 96 * 1
......@@ -449,7 +452,7 @@ class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark):
"""Test bert pretraining with 2x8 GPU for 200 steps."""
self._setup()
self._specify_common_flags()
self._specify_gpu_common_flags()
self._specify_gpu_mwms_flags()
FLAGS.num_steps_per_epoch = 200
FLAGS.num_train_epochs = 1
FLAGS.train_batch_size = 96 * 2
......@@ -469,7 +472,7 @@ class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark):
"""Test bert pretraining with 8x8 GPU for 200 steps."""
self._setup()
self._specify_common_flags()
self._specify_gpu_common_flags()
self._specify_gpu_mwms_flags()
FLAGS.num_steps_per_epoch = 200
FLAGS.num_train_epochs = 1
FLAGS.train_batch_size = 96*8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment