Commit 50fe1963 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 330007126
parent 16e2edce
...@@ -107,10 +107,13 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase): ...@@ -107,10 +107,13 @@ class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
strategy=distribution, custom_callbacks=self.timer_callback) strategy=distribution, custom_callbacks=self.timer_callback)
wall_time_sec = time.time() - start_time_sec wall_time_sec = time.time() - start_time_sec
with tf.io.gfile.GFile(summary_path, 'rb') as reader: # For GPU multi-worker, the summary text file is only generated on chief
summary = json.loads(reader.read().decode('utf-8')) # (metrics aggregated), so only chief has to report the result.
self._report_benchmark(summary, start_time_sec, wall_time_sec, if tf.io.gfile.exists(summary_path):
report_accuracy, ds_type) with tf.io.gfile.GFile(summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
self._report_benchmark(summary, start_time_sec, wall_time_sec,
report_accuracy, ds_type)
def _report_benchmark(self, summary, start_time_sec, wall_time_sec, def _report_benchmark(self, summary, start_time_sec, wall_time_sec,
report_accuracy, ds_type): report_accuracy, ds_type):
...@@ -429,7 +432,7 @@ class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark): ...@@ -429,7 +432,7 @@ class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark):
"""Test bert pretraining with 1x8 GPU for 200 steps.""" """Test bert pretraining with 1x8 GPU for 200 steps."""
self._setup() self._setup()
self._specify_common_flags() self._specify_common_flags()
self._specify_gpu_common_flags() self._specify_gpu_mwms_flags()
FLAGS.num_steps_per_epoch = 200 FLAGS.num_steps_per_epoch = 200
FLAGS.num_train_epochs = 1 FLAGS.num_train_epochs = 1
FLAGS.train_batch_size = 96 * 1 FLAGS.train_batch_size = 96 * 1
...@@ -449,7 +452,7 @@ class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark): ...@@ -449,7 +452,7 @@ class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark):
"""Test bert pretraining with 2x8 GPU for 200 steps.""" """Test bert pretraining with 2x8 GPU for 200 steps."""
self._setup() self._setup()
self._specify_common_flags() self._specify_common_flags()
self._specify_gpu_common_flags() self._specify_gpu_mwms_flags()
FLAGS.num_steps_per_epoch = 200 FLAGS.num_steps_per_epoch = 200
FLAGS.num_train_epochs = 1 FLAGS.num_train_epochs = 1
FLAGS.train_batch_size = 96 * 2 FLAGS.train_batch_size = 96 * 2
...@@ -469,7 +472,7 @@ class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark): ...@@ -469,7 +472,7 @@ class BertPretrainMultiWorkerBenchmark(BertPretrainAccuracyBenchmark):
"""Test bert pretraining with 8x8 GPU for 200 steps.""" """Test bert pretraining with 8x8 GPU for 200 steps."""
self._setup() self._setup()
self._specify_common_flags() self._specify_common_flags()
self._specify_gpu_common_flags() self._specify_gpu_mwms_flags()
FLAGS.num_steps_per_epoch = 200 FLAGS.num_steps_per_epoch = 200
FLAGS.num_train_epochs = 1 FLAGS.num_train_epochs = 1
FLAGS.train_batch_size = 96*8 FLAGS.train_batch_size = 96*8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment