Commit a617b671 authored by David Chen's avatar David Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 285322154
parent 0f5bdd0e
...@@ -51,10 +51,11 @@ FLAGS = flags.FLAGS ...@@ -51,10 +51,11 @@ FLAGS = flags.FLAGS
class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module.""" """Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None): def __init__(self, output_dir=None, tpu=None):
super(BertClassifyBenchmarkBase, self).__init__(output_dir) super(BertClassifyBenchmarkBase, self).__init__(output_dir)
self.num_epochs = None self.num_epochs = None
self.num_steps_per_epoch = None self.num_steps_per_epoch = None
self.tpu = tpu
@flagsaver.flagsaver @flagsaver.flagsaver
def _run_bert_classifier(self, callbacks=None, use_ds=True): def _run_bert_classifier(self, callbacks=None, use_ds=True):
...@@ -72,6 +73,10 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -72,6 +73,10 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
warmup_steps = int(epochs * steps_per_epoch * 0.1) warmup_steps = int(epochs * steps_per_epoch * 0.1)
eval_steps = int( eval_steps = int(
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size)) math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
if self.tpu:
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='tpu', tpu_address=self.tpu)
else:
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='mirrored' if use_ds else 'off', distribution_strategy='mirrored' if use_ds else 'off',
num_gpus=self.num_gpus) num_gpus=self.num_gpus)
...@@ -109,13 +114,15 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -109,13 +114,15 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase): class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
"""Short benchmark performance tests for BERT model. """Short benchmark performance tests for BERT model.
Tests BERT classification performance in different GPU configurations. Tests BERT classification performance in different GPU, TPU configurations.
The naming convention of below test cases follow The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu_(dataset type)` format. `benchmark_(number of gpus)_gpu_(dataset type)` for GPUs and
`benchmark_(topology)_tpu_(dataset type)` for TPUs.
""" """
def __init__(self, output_dir=TMP_DIR, **kwargs): def __init__(self, output_dir=TMP_DIR, tpu=None, **kwargs):
super(BertClassifyBenchmarkReal, self).__init__(output_dir=output_dir) super(BertClassifyBenchmarkReal, self).__init__(
output_dir=output_dir, tpu=tpu)
self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
...@@ -289,6 +296,22 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase): ...@@ -289,6 +296,22 @@ class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
'summaries/training_summary.txt') 'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False) self._run_and_report_benchmark(summary_path, use_ds=False)
def benchmark_2x2_tpu_mrpc(self):
"""Test BERT model performance with 2x2 TPU."""
self._setup()
FLAGS.model_dir = self._get_model_dir('benchmark_2x2_tpu_mrpc')
FLAGS.train_data_path = self.train_data_path
FLAGS.eval_data_path = self.eval_data_path
FLAGS.input_meta_data_path = self.input_meta_data_path
FLAGS.bert_config_file = self.bert_config_file
FLAGS.train_batch_size = 32
FLAGS.eval_batch_size = 32
summary_path = os.path.join(FLAGS.model_dir,
'summaries/training_summary.txt')
self._run_and_report_benchmark(summary_path, use_ds=False)
class BertClassifyAccuracy(BertClassifyBenchmarkBase): class BertClassifyAccuracy(BertClassifyBenchmarkBase):
"""Short accuracy test for BERT model. """Short accuracy test for BERT model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment