Commit 0612e190 authored by David Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 286087529
parent 7e67dbbc
--- a/official/benchmark/bert_benchmark_utils.py
+++ b/official/benchmark/bert_benchmark_utils.py
@@ -18,17 +18,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import os
 import time
 
 # pylint: disable=g-bad-import-order
 import numpy as np
 from absl import flags
-from absl.testing import flagsaver
 import tensorflow.compat.v2 as tf
 # pylint: enable=g-bad-import-order
 
 from official.utils.flags import core as flags_core
+from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
 
 FLAGS = flags.FLAGS
@@ -59,34 +58,20 @@ class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
     return self.batch_start_times[0] - program_start_time
 
 
-class BertBenchmarkBase(tf.test.Benchmark):
+class BertBenchmarkBase(PerfZeroBenchmark):
   """Base class to hold methods common to test classes."""
   local_flags = None
 
   def __init__(self, output_dir=None):
+    super(BertBenchmarkBase, self).__init__(output_dir=output_dir)
     self.num_gpus = 8
-    if not output_dir:
-      output_dir = '/tmp'
-    self.output_dir = output_dir
     self.timer_callback = None
 
-  def _get_model_dir(self, folder_name):
-    """Returns directory to store info, e.g. saved model and event log."""
-    return os.path.join(self.output_dir, folder_name)
-
   def _setup(self):
     """Sets up and resets flags before each test."""
+    super(BertBenchmarkBase, self)._setup()
     self.timer_callback = BenchmarkTimerCallback()
-    if BertBenchmarkBase.local_flags is None:
-      # Loads flags to get defaults to then override. List cannot be empty.
-      flags.FLAGS(['foo'])
-      saved_flag_values = flagsaver.save_flag_values()
-      BertBenchmarkBase.local_flags = saved_flag_values
-    else:
-      flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)
 
   def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
     """Report benchmark results by writing to local protobuf file.
--- a/official/benchmark/bert_squad_benchmark.py
+++ b/official/benchmark/bert_squad_benchmark.py
@@ -52,6 +52,10 @@ FLAGS = flags.FLAGS
 class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
   """Base class to hold methods common to test classes in the module."""
 
+  def __init__(self, output_dir=None, tpu=None):
+    super(BertSquadBenchmarkBase, self).__init__(output_dir=output_dir)
+    self.tpu = tpu
+
   def _read_training_summary_from_file(self):
     """Reads the training summary from a file."""
     summary_path = os.path.join(FLAGS.model_dir,
@@ -78,9 +82,13 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
   def _get_distribution_strategy(self, use_ds=True):
     """Gets the distribution strategy."""
-    return distribution_utils.get_distribution_strategy(
-        distribution_strategy='mirrored' if use_ds else 'off',
-        num_gpus=self.num_gpus)
+    if self.tpu:
+      return distribution_utils.get_distribution_strategy(
+          distribution_strategy='tpu', tpu_address=self.tpu)
+    else:
+      return distribution_utils.get_distribution_strategy(
+          distribution_strategy='mirrored' if use_ds else 'off',
+          num_gpus=self.num_gpus)
 
   @flagsaver.flagsaver
   def _train_squad(self, use_ds=True, run_eagerly=False):
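After this hunk, a single constructor argument switches the whole benchmark between TPU and GPU execution: a non-empty `tpu` address routes `_get_distribution_strategy` to a TPU strategy, while the previous mirrored/off behavior remains the fallback. A hedged usage sketch follows; the TPU address is a made-up placeholder, not a value from this commit.

```python
# Hypothetical usage of the new dispatch in _get_distribution_strategy.
# 'grpc://10.0.0.2:8470' is a placeholder TPU address.
tpu_bench = BertSquadBenchmarkBase(output_dir='/tmp',
                                   tpu='grpc://10.0.0.2:8470')
tpu_strategy = tpu_bench._get_distribution_strategy()  # TPU strategy

gpu_bench = BertSquadBenchmarkBase(output_dir='/tmp')  # tpu defaults to None
gpu_strategy = gpu_bench._get_distribution_strategy()  # mirrored strategy
no_strategy = gpu_bench._get_distribution_strategy(use_ds=False)  # 'off'
```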
@@ -117,11 +125,12 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
   Tests BERT SQuAD performance in different GPU configurations.
 
   The naming convention of below test cases follow
-  `benchmark_(number of gpus)_gpu` format.
+  `benchmark_(number of gpus)_gpu` format for GPUs and
+  `benchmark_(topology)_tpu` format for TPUs.
   """
 
-  def __init__(self, output_dir=TMP_DIR, **kwargs):
-    super(BertSquadBenchmarkReal, self).__init__(output_dir=output_dir)
+  def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
+    super(BertSquadBenchmarkReal, self).__init__(output_dir=output_dir, tpu=tpu)
 
   def _setup(self):
     """Sets up the benchmark and SQuAD flags."""
@@ -322,16 +331,26 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
     self._run_and_report_benchmark()
 
+  def benchmark_2x2_tpu(self):
+    """Tests BERT SQuAD model performance with 2x2 TPU."""
+    self._setup()
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
+    FLAGS.train_batch_size = 48
+
+    self._run_and_report_benchmark()
+
 
 class BertSquadAccuracy(BertSquadBenchmarkBase):
   """Short accuracy test for BERT SQuAD model.
 
   Tests BERT SQuAD accuracy. The naming convention of below test cases follow
-  `benchmark_(number of gpus)_gpu` format.
+  `benchmark_(number of gpus)_gpu` format for GPUs and
+  `benchmark_(topology)_tpu` format for TPUs.
   """
 
-  def __init__(self, output_dir=None, **kwargs):
-    super(BertSquadAccuracy, self).__init__(output_dir=output_dir)
+  def __init__(self, output_dir=None, tpu=None, **kwargs):
+    super(BertSquadAccuracy, self).__init__(output_dir=output_dir, tpu=tpu)
 
   def _setup(self):
     """Sets up the benchmark and SQuAD flags."""
@@ -407,6 +426,15 @@ class BertSquadAccuracy(BertSquadBenchmarkBase):
     self._run_and_report_benchmark()
 
+  def benchmark_2x2_tpu(self):
+    """Tests BERT SQuAD model accuracy with 2x2 TPU."""
+    self._setup()
+    FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu')
+    FLAGS.train_batch_size = 48
+
+    self._run_and_report_benchmark()
+
 
 if __name__ == '__main__':
   tf.test.main()
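In normal operation these `benchmark_*` methods are discovered and run by the PerfZero benchmark harness, but the new TPU variants can also be driven directly once the SQuAD data flags are configured. A sketch, again with a made-up TPU address:

```python
# Hypothetical direct invocation of the new TPU benchmark; the address
# is a placeholder, and these methods normally run under PerfZero with
# the SQuAD input flags already set.
benchmark = BertSquadBenchmarkReal(output_dir='/tmp',
                                   tpu='grpc://10.0.0.2:8470')
benchmark.benchmark_2x2_tpu()
```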