Commit 415e8a45 authored by davidmochen's avatar davidmochen Committed by saberkun
Browse files

Add BERT SQuAD benchmark (#6976)

parent 42a8af1d
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions or classes shared between BERT benchmarks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
FLAGS = flags.FLAGS
class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
"""Callback that records time it takes to run each batch."""
def __init__(self, num_batches_to_skip=10):
super(BenchmarkTimerCallback, self).__init__()
self.num_batches_to_skip = num_batches_to_skip
self.timer_records = []
self.start_time = None
def on_batch_start(self, batch, logs=None):
if batch < self.num_batches_to_skip:
return
self.start_time = time.time()
def on_batch_end(self, batch, logs=None):
if batch < self.num_batches_to_skip:
return
assert self.start_time
self.timer_records.append(time.time() - self.start_time)
def get_examples_per_sec(self, batch_size):
return batch_size / np.mean(self.timer_records)
class BertBenchmarkBase(tf.test.Benchmark):
"""Base class to hold methods common to test classes."""
local_flags = None
def __init__(self, output_dir=None):
self.num_gpus = 8
if not output_dir:
output_dir = '/tmp'
self.output_dir = output_dir
self.timer_callback = None
def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log."""
return os.path.join(self.output_dir, folder_name)
def _setup(self):
"""Sets up and resets flags before each test."""
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
self.timer_callback = BenchmarkTimerCallback()
if BertBenchmarkBase.local_flags is None:
# Loads flags to get defaults to then override. List cannot be empty.
flags.FLAGS(['foo'])
saved_flag_values = flagsaver.save_flag_values()
BertBenchmarkBase.local_flags = saved_flag_values
else:
flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)
def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
"""Report benchmark results by writing to local protobuf file.
Args:
stats: dict returned from BERT models with known entries.
wall_time_sec: the during of the benchmark execution in seconds
min_accuracy: Minimum classification accuracy constraint to verify
correctness of the model.
max_accuracy: Maximum classification accuracy constraint to verify
correctness of the model.
"""
metrics = [{
'name': 'training_loss',
'value': stats['train_loss'],
}, {
'name':
'exp_per_second',
'value':
self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size)
}]
if 'eval_metrics' in stats:
metrics.append({
'name': 'eval_accuracy',
'value': stats['eval_metrics'],
'min_value': min_accuracy,
'max_value': max_accuracy,
})
self.report_benchmark(
iters=stats['total_training_steps'],
wall_time=wall_time_sec,
metrics=metrics)
......@@ -24,7 +24,6 @@ import os
import time
# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
......@@ -32,6 +31,7 @@ import tensorflow as tf
from official.bert import modeling
from official.bert import run_classifier
from official.bert.benchmark import benchmark_utils
from official.utils.misc import distribution_utils
# pylint: disable=line-too-long
......@@ -45,95 +45,14 @@ MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1
FLAGS = flags.FLAGS
class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
"""Callback that records time it takes to run each batch."""
def __init__(self, num_batches_to_skip=10):
super(BenchmarkTimerCallback, self).__init__()
self.num_batches_to_skip = num_batches_to_skip
self.timer_records = []
self.start_time = None
def on_batch_start(self, batch, logs=None):
if batch < self.num_batches_to_skip:
return
self.start_time = time.time()
def on_batch_end(self, batch, logs=None):
if batch < self.num_batches_to_skip:
return
assert self.start_time
self.timer_records.append(time.time() - self.start_time)
def get_examples_per_sec(self, batch_size):
return batch_size / np.mean(self.timer_records)
class BertBenchmarkBase(tf.test.Benchmark):
class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
local_flags = None
def __init__(self, output_dir=None):
self.num_gpus = 8
self.num_epochs = None
self.num_steps_per_epoch = None
if not output_dir:
output_dir = '/tmp'
self.output_dir = output_dir
self.timer_callback = None
def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log."""
return os.path.join(self.output_dir, folder_name)
def _setup(self):
"""Sets up and resets flags before each test."""
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
self.timer_callback = BenchmarkTimerCallback()
if BertBenchmarkBase.local_flags is None:
# Loads flags to get defaults to then override. List cannot be empty.
flags.FLAGS(['foo'])
saved_flag_values = flagsaver.save_flag_values()
BertBenchmarkBase.local_flags = saved_flag_values
else:
flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)
def _report_benchmark(self, stats, wall_time_sec, min_accuracy, max_accuracy):
"""Report benchmark results by writing to local protobuf file.
Args:
stats: dict returned from BERT models with known entries.
wall_time_sec: the during of the benchmark execution in seconds
min_accuracy: Minimum classification accuracy constraint to verify
correctness of the model.
max_accuracy: Maximum classification accuracy constraint to verify
correctness of the model.
"""
metrics = [{
'name': 'training_loss',
'value': stats['train_loss'],
}, {
'name':
'exp_per_second',
'value':
self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size)
}]
if 'eval_metrics' in stats:
metrics.append({
'name': 'eval_accuracy',
'value': stats['eval_metrics'],
'min_value': min_accuracy,
'max_value': max_accuracy,
})
self.report_benchmark(
iters=stats['total_training_steps'],
wall_time=wall_time_sec,
metrics=metrics)
super(BertClassifyBenchmarkBase, self).__init__(output_dir)
@flagsaver.flagsaver
def _run_bert_classifier(self, callbacks=None):
......@@ -168,7 +87,7 @@ class BertBenchmarkBase(tf.test.Benchmark):
custom_callbacks=callbacks)
class BertClassifyBenchmarkReal(BertBenchmarkBase):
class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
"""Short benchmark performance tests for BERT model.
Tests BERT classification performance in different GPU configurations.
......@@ -272,7 +191,7 @@ class BertClassifyBenchmarkReal(BertBenchmarkBase):
self._run_and_report_benchmark(summary_path)
class BertClassifyAccuracy(BertBenchmarkBase):
class BertClassifyAccuracy(BertClassifyBenchmarkBase):
"""Short accuracy test for BERT model.
Tests BERT classification task model accuracy. The naming
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes BERT SQuAD benchmarks and accuracy tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import time
# pylint: disable=g-bad-import-order
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.bert import run_squad
from official.bert.benchmark import benchmark_utils
from official.utils.misc import distribution_utils
# pylint: disable=line-too-long
SQUAD_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_train.tf_record'
SQUAD_PREDICT_FILE = 'gs://tf-perfzero-data/bert/squad/dev-v1.1.json'
SQUAD_VOCAB_FILE = 'gs://tf-perfzero-data/bert/squad/vocab.txt'
SQUAD_SMALL_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/squad/squad_small_meta_data'
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_config'
# pylint: enable=line-too-long
FLAGS = flags.FLAGS
class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
@flagsaver.flagsaver
def _run_bert_squad(self):
"""Starts BERT SQuAD task."""
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='mirrored', num_gpus=self.num_gpus)
run_squad.train_squad(
strategy=strategy,
input_meta_data=input_meta_data,
custom_callbacks=[self.timer_callback])
class BertSquadBenchmark(BertSquadBenchmarkBase):
"""Short benchmark performance tests for BERT SQuAD model.
Tests BERT SQuAD performance in different GPU configurations.
The naming convention of below test cases follow
`benchmark_(number of gpus)_gpu` format.
"""
def __init__(self, output_dir=None, **kwargs):
super(BertSquadBenchmark, self).__init__(output_dir=output_dir)
def _setup(self):
super(BertSquadBenchmark, self)._setup()
FLAGS.train_data_path = SQUAD_TRAIN_DATA_PATH
FLAGS.predict_file = SQUAD_PREDICT_FILE
FLAGS.vocab_file = SQUAD_VOCAB_FILE
FLAGS.input_meta_data_path = SQUAD_SMALL_INPUT_META_DATA_PATH
FLAGS.bert_config_file = MODEL_CONFIG_FILE_PATH
FLAGS.num_train_epochs = 1
def _run_and_report_benchmark(self,
training_summary_path,
min_accuracy=0,
max_accuracy=1):
"""Starts BERT SQuAD performance benchmark test."""
start_time_sec = time.time()
self._run_bert_squad()
wall_time_sec = time.time() - start_time_sec
with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
summary = json.loads(reader.read().decode('utf-8'))
super(BertSquadBenchmark, self)._report_benchmark(
stats=summary,
wall_time_sec=wall_time_sec,
min_accuracy=min_accuracy,
max_accuracy=max_accuracy)
def benchmark_1_gpu(self):
"""Test BERT SQuAD model performance with 1 GPU."""
self._setup()
self.num_gpus = 1
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_squad')
FLAGS.train_batch_size = 4
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_2_gpu(self):
"""Test BERT SQuAD model performance with 2 GPUs."""
self._setup()
self.num_gpus = 2
FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_squad')
FLAGS.train_batch_size = 8
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_4_gpu(self):
"""Test BERT SQuAD model performance with 4 GPUs."""
self._setup()
self.num_gpus = 4
FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_squad')
FLAGS.train_batch_size = 16
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
self._run_and_report_benchmark(summary_path)
def benchmark_8_gpu(self):
"""Test BERT SQuAD model performance with 8 GPUs."""
self._setup()
self.num_gpus = 8
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_squad')
FLAGS.train_batch_size = 32
summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
self._run_and_report_benchmark(summary_path)
if __name__ == '__main__':
tf.test.main()
......@@ -287,8 +287,8 @@ def run_customized_training_loop(
if eval_metric_result:
training_summary['eval_metrics'] = eval_metric_result
summary_path = os.path.join(model_dir, SUMMARY_TXT)
with tf.io.gfile.GFile(summary_path, 'wb') as f:
f.write(json.dumps(training_summary, indent=4))
summary_path = os.path.join(model_dir, SUMMARY_TXT)
with tf.io.gfile.GFile(summary_path, 'wb') as f:
f.write(json.dumps(training_summary, indent=4))
return model
......@@ -189,7 +189,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
return all_results
def train_squad(strategy, input_meta_data):
def train_squad(strategy, input_meta_data, custom_callbacks=None):
"""Run bert squad training."""
if not strategy:
raise ValueError('Distribution strategy cannot be None.')
......@@ -233,7 +233,8 @@ def train_squad(strategy, input_meta_data):
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
use_remote_tpu=use_remote_tpu)
use_remote_tpu=use_remote_tpu,
custom_callbacks=custom_callbacks)
def predict_squad(strategy, input_meta_data):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment