Commit 415e8a45 authored by davidmochen, committed by saberkun

Add BERT SQuAD benchmark (#6976)

parent 42a8af1d
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions or classes shared between BERT benchmarks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
FLAGS = flags.FLAGS
class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
"""Callback that records time it takes to run each batch."""
def __init__(self, num_batches_to_skip=10):
super(BenchmarkTimerCallback, self).__init__()
self.num_batches_to_skip = num_batches_to_skip
self.timer_records = []
self.start_time = None
  def on_batch_begin(self, batch, logs=None):
if batch < self.num_batches_to_skip:
return
self.start_time = time.time()
def on_batch_end(self, batch, logs=None):
if batch < self.num_batches_to_skip:
return
assert self.start_time
self.timer_records.append(time.time() - self.start_time)
def get_examples_per_sec(self, batch_size):
return batch_size / np.mean(self.timer_records)
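# Usage sketch for BenchmarkTimerCallback (assuming an ordinary Keras training
# loop; the benchmarks below wire it up through `custom_callbacks` instead):
#   timer = BenchmarkTimerCallback()
#   model.fit(train_dataset, epochs=1, callbacks=[timer])
#   examples_per_sec = timer.get_examples_per_sec(global_batch_size)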
class BertBenchmarkBase(tf.test.Benchmark):
"""Base class to hold methods common to test classes."""
local_flags = None
def __init__(self, output_dir=None):
self.num_gpus = 8
if not output_dir:
output_dir = '/tmp'
self.output_dir = output_dir
self.timer_callback = None
def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log."""
return os.path.join(self.output_dir, folder_name)
def _setup(self):
"""Sets up and resets flags before each test."""
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
self.timer_callback = BenchmarkTimerCallback()
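    # Flag values are captured once per process and restored before every test
    # so that per-benchmark overrides do not leak across test methods.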
if BertBenchmarkBase.local_flags is None:
# Loads flags to get defaults to then override. List cannot be empty.
flags.FLAGS(['foo'])
saved_flag_values = flagsaver.save_flag_values()
BertBenchmarkBase.local_flags = saved_flag_values
else:
flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)
def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
"""Report benchmark results by writing to local protobuf file.
Args:
stats: dict returned from BERT models with known entries.
      wall_time_sec: The duration of the benchmark execution, in seconds.
min_accuracy: Minimum classification accuracy constraint to verify
correctness of the model.
max_accuracy: Maximum classification accuracy constraint to verify
correctness of the model.
"""
metrics = [{
'name': 'training_loss',
'value': stats['train_loss'],
}, {
'name':
'exp_per_second',
'value':
self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size)
}]
if 'eval_metrics' in stats:
metrics.append({
'name': 'eval_accuracy',
'value': stats['eval_metrics'],
'min_value': min_accuracy,
'max_value': max_accuracy,
})
self.report_benchmark(
iters=stats['total_training_steps'],
wall_time=wall_time_sec,
metrics=metrics)
@@ -24,7 +24,6 @@ import os
 import time
 # pylint: disable=g-bad-import-order
-import numpy as np
 from absl import flags
 from absl.testing import flagsaver
 import tensorflow as tf
@@ -32,6 +31,7 @@ import tensorflow as tf
 from official.bert import modeling
 from official.bert import run_classifier
+from official.bert.benchmark import benchmark_utils
 from official.utils.misc import distribution_utils
 # pylint: disable=line-too-long
@@ -45,95 +45,14 @@ MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1
 FLAGS = flags.FLAGS


-class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
-  """Callback that records time it takes to run each batch."""
-
-  def __init__(self, num_batches_to_skip=10):
-    super(BenchmarkTimerCallback, self).__init__()
-    self.num_batches_to_skip = num_batches_to_skip
-    self.timer_records = []
-    self.start_time = None
-
-  def on_batch_start(self, batch, logs=None):
-    if batch < self.num_batches_to_skip:
-      return
-    self.start_time = time.time()
-
-  def on_batch_end(self, batch, logs=None):
-    if batch < self.num_batches_to_skip:
-      return
-    assert self.start_time
-    self.timer_records.append(time.time() - self.start_time)
-
-  def get_examples_per_sec(self, batch_size):
-    return batch_size / np.mean(self.timer_records)
-
-
-class BertBenchmarkBase(tf.test.Benchmark):
+class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
   """Base class to hold methods common to test classes in the module."""
-  local_flags = None

   def __init__(self, output_dir=None):
-    self.num_gpus = 8
     self.num_epochs = None
     self.num_steps_per_epoch = None
-    if not output_dir:
-      output_dir = '/tmp'
-    self.output_dir = output_dir
-    self.timer_callback = None
-
-  def _get_model_dir(self, folder_name):
-    """Returns directory to store info, e.g. saved model and event log."""
-    return os.path.join(self.output_dir, folder_name)
-
-  def _setup(self):
-    """Sets up and resets flags before each test."""
-    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
-    self.timer_callback = BenchmarkTimerCallback()
-    if BertBenchmarkBase.local_flags is None:
-      # Loads flags to get defaults to then override. List cannot be empty.
-      flags.FLAGS(['foo'])
-      saved_flag_values = flagsaver.save_flag_values()
-      BertBenchmarkBase.local_flags = saved_flag_values
-    else:
-      flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)
-
-  def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
-    """Report benchmark results by writing to local protobuf file.
-
-    Args:
-      stats: dict returned from BERT models with known entries.
-      wall_time_sec: the during of the benchmark execution in seconds
-      min_accuracy: Minimum classification accuracy constraint to verify
-        correctness of the model.
-      max_accuracy: Maximum classification accuracy constraint to verify
-        correctness of the model.
-    """
-    metrics = [{
-        'name': 'training_loss',
-        'value': stats['train_loss'],
-    }, {
-        'name':
-            'exp_per_second',
-        'value':
-            self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size)
-    }]
-    if 'eval_metrics' in stats:
-      metrics.append({
-          'name': 'eval_accuracy',
-          'value': stats['eval_metrics'],
-          'min_value': min_accuracy,
-          'max_value': max_accuracy,
-      })
-    self.report_benchmark(
-        iters=stats['total_training_steps'],
-        wall_time=wall_time_sec,
-        metrics=metrics)
+    super(BertClassifyBenchmarkBase, self).__init__(output_dir)

   @flagsaver.flagsaver
   def _run_bert_classifier(self, callbacks=None):
@@ -168,7 +87,7 @@ class BertBenchmarkBase(tf.test.Benchmark):
       custom_callbacks=callbacks)


-class BertClassifyBenchmarkReal(BertBenchmarkBase):
+class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
   """Short benchmark performance tests for BERT model.

   Tests BERT classification performance in different GPU configurations.
@@ -272,7 +191,7 @@ class BertClassifyBenchmarkReal(BertBenchmarkBase):
     self._run_and_report_benchmark(summary_path)


-class BertClassifyAccuracy(BertBenchmarkBase):
+class BertClassifyAccuracy(BertClassifyBenchmarkBase):
   """Short accuracy test for BERT model.

   Tests BERT classification task model accuracy. The naming
...
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes BERT SQuAD benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import time
# pylint: disable=g-bad-import-order
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.bert import run_squad
from official.bert.benchmark import benchmark_utils
from official.utils.misc import distribution_utils
# pylint: disable=line-too-long
SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt'
SQUAD_SMALL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_small_meta_data'
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_config'
# pylint: enable=line-too-long
FLAGS = flags.FLAGS
class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
@flagsaver.flagsaver
def _run_bert_squad(self):
"""Starts BERT SQuAD task."""
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='mirrored', num_gpus=self.num_gpus)
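    # Training runs under MirroredStrategy on `self.num_gpus` GPUs; the timer
    # callback is passed through so per-batch step times feed the
    # exp_per_second metric.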
run_squad.train_squad(
strategy=strategy,
input_meta_data=input_meta_data,
custom_callbacks=[self.timer_callback])
class BertSquadBenchmark(BertSquadBenchmarkBase):
"""Short benchmark performance tests for BERT SQuAD model.
Tests BERT SQuAD performance in different GPU configurations.
  The test cases below follow the naming convention
  `benchmark_(number of gpus)_gpu`.
"""
def __init__(self, output_dir=None, **kwargs):
super(BertSquadBenchmark, self).__init__(output_dir=output_dir)
def _setup(self):
super(BertSquadBenchmark, self)._setup()
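    # Point the input flags at the shared GCS test data and limit training to a
    # single epoch so the performance runs stay short.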
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
FLAGS.predict_file = SQUAD_PREDICT_FILE
FLAGS.vocab_file = SQUAD_VOCAB_FILE
FLAGS.input_meta_data_path = SQUAD_SMALL_INPUT_META_DATA_PATH
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.num_train_epochs = 1
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0,
max_accuracy=1):
"""Starts BERT SQuAD performance benchmark test."""
start_time_sec = time.time()
self._run_bert_squad()
wall_time_sec = time.time() - start_time_sec
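    # The training loop writes a JSON summary (e.g. train_loss,
    # total_training_steps, and optionally eval_metrics) that _report_benchmark
    # turns into reported metrics.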
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
super(BertSquadBenchmark, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=min_accuracy,
max_accuracy=max_accuracy)
def benchmark_1_gpu(self):
"""Test BERT SQuAD model performance with 1 GPU."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad')
FLAGS.train_batch_size = 4
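    # Batch size scales linearly with GPU count (4 examples per GPU) across the
    # benchmark_*_gpu cases below.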
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_2_gpu(self):
"""Test BERT SQuAD model performance with 2 GPUs."""
self._setup()
self.num_gpus = 2
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_squad')
FLAGS.train_batch_size = 8
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_4_gpu(self):
"""Test BERT SQuAD model performance with 4 GPUs."""
self._setup()
self.num_gpus = 4
FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_squad')
FLAGS.train_batch_size = 16
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_8_gpu(self):
"""Test BERT SQuAD model performance with 8 GPUs."""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
FLAGS.train_batch_size = 32
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
self._run_and_report_benchmark(summary_path)
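# Benchmark methods follow tf.test.Benchmark conventions; when this module is
# executed directly, individual cases can typically be selected with the
# --benchmarks=<regex> flag (e.g. --benchmarks=benchmark_8_gpu), assuming the
# standard TensorFlow benchmark harness.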
if __name__ == '__main__':
tf.test.main()
@@ -287,8 +287,8 @@ def run_customized_training_loop(
     if eval_metric_result:
       training_summary['eval_metrics'] = eval_metric_result

     summary_path = os.path.join(model_dir, SUMMARY_TXT)
     with tf.io.gfile.GFile(summary_path, 'wb') as f:
       f.write(json.dumps(training_summary, indent=4))

     return model
@@ -189,7 +189,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
   return all_results


-def train_squad(strategy, input_meta_data):
+def train_squad(strategy, input_meta_data, custom_callbacks=None):
   """Run bert squad training."""
   if not strategy:
     raise ValueError('Distribution strategy cannot be None.')
@@ -233,7 +233,8 @@ def train_squad(strategy, input_meta_data):
       epochs=epochs,
       train_input_fn=train_input_fn,
       init_checkpoint=FLAGS.init_checkpoint,
-      use_remote_tpu=use_remote_tpu)
+      use_remote_tpu=use_remote_tpu,
+      custom_callbacks=custom_callbacks)


 def predict_squad(strategy, input_meta_data):
...