# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes BERT benchmarks and accuracy tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import math
import os
import time

# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.bert import modeling
from official.bert import run_classifier
from official.utils.misc import distribution_utils

# pylint: disable=line-too-long
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_model.ckpt'
CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_train.tf_record'
CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_eval.tf_record'
CLASSIFIER_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_meta_data'
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/tf_20/uncased_L-24_H-1024_A-16/bert_config'
# pylint: enable=line-too-long

FLAGS = flags.FLAGS


class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
  """Callback that records the time taken to run each batch."""

  def __init__(self, num_batches_to_skip=10):
    super(BenchmarkTimerCallback, self).__init__()
    # Skip the first few batches so warm-up costs (e.g. graph tracing) do not
    # skew the throughput measurement.
    self.num_batches_to_skip = num_batches_to_skip
    self.timer_records = []
    self.start_time = None

  def on_batch_begin(self, batch, logs=None):
    if batch < self.num_batches_to_skip:
      return
    self.start_time = time.time()

  def on_batch_end(self, batch, logs=None):
    if batch < self.num_batches_to_skip:
      return

    assert self.start_time
    self.timer_records.append(time.time() - self.start_time)

  def get_examples_per_sec(self, batch_size):
    return batch_size / np.mean(self.timer_records)
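

# A minimal usage sketch (not part of the benchmark suite) of how
# BenchmarkTimerCallback plugs into a standard Keras fit() loop; the toy
# model and random data below are illustrative only.
def _timer_callback_example():
  """Returns examples/sec measured over a toy model.fit() run."""
  model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(8,))])
  model.compile(optimizer='sgd', loss='mse')
  features = np.random.rand(128, 8).astype(np.float32)
  labels = np.random.rand(128, 2).astype(np.float32)
  timer = BenchmarkTimerCallback(num_batches_to_skip=2)
  model.fit(features, labels, batch_size=4, epochs=1, callbacks=[timer],
            verbose=0)
  return timer.get_examples_per_sec(batch_size=4)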


class BertBenchmarkBase(tf.test.Benchmark):
  """Base class to hold methods common to test classes in the module."""
  local_flags = None

  def __init__(self, output_dir=None):
    self.num_gpus = 8
    self.num_epochs = None
    self.num_steps_per_epoch = None

    if not output_dir:
      output_dir = '/tmp'
    self.output_dir = output_dir
    self.timer_callback = None

  def _get_model_dir(self, folder_name):
    """Returns directory to store info, e.g. saved model and event log."""
    return os.path.join(self.output_dir, folder_name)

  def _setup(self):
    """Sets up and resets flags before each test."""
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
    self.timer_callback = BenchmarkTimerCallback()

    if BertBenchmarkBase.local_flags is None:
      # Loads flags to get defaults to then override. List cannot be empty.
      flags.FLAGS(['foo'])
      saved_flag_values = flagsaver.save_flag_values()
      BertBenchmarkBase.local_flags = saved_flag_values
    else:
      flagsaver.restore_flag_values(BertBenchmarkBase.local_flags)

  def _report_benchmark(self, stats, wall_time_sec, min_accuracy,
                        max_accuracy):
    """Report benchmark results by writing to a local protobuf file.

    Args:
      stats: dict returned from BERT models with known entries.
      wall_time_sec: the duration of the benchmark execution in seconds.
      min_accuracy: Minimum classification accuracy constraint to verify
        correctness of the model.
      max_accuracy: Maximum classification accuracy constraint to verify
        correctness of the model.
    """
    metrics = [{
        'name': 'training_loss',
        'value': stats['train_loss'],
    }, {
        'name': 'exp_per_second',
        'value': self.timer_callback.get_examples_per_sec(
            FLAGS.train_batch_size),
    }]

    if 'eval_metrics' in stats:
      metrics.append({
          'name': 'eval_accuracy',
          'value': stats['eval_metrics'],
          'min_value': min_accuracy,
          'max_value': max_accuracy,
      })

    self.report_benchmark(
        iters=stats['total_training_steps'],
        wall_time=wall_time_sec,
        metrics=metrics)

  @flagsaver.flagsaver
  def _run_bert_classifier(self, callbacks=None):
    """Starts BERT classification task."""
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
      input_meta_data = json.loads(reader.read().decode('utf-8'))
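
    # The meta data JSON is expected to provide at least 'train_data_size'
    # and 'eval_data_size', which are used below to derive step counts.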

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    epochs = self.num_epochs if self.num_epochs else FLAGS.num_train_epochs
    if self.num_steps_per_epoch:
      steps_per_epoch = self.num_steps_per_epoch
    else:
      train_data_size = input_meta_data['train_data_size']
      steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
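    # Warm up over the first 10% of total training steps, a common default
    # for BERT fine-tuning.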
    warmup_steps = int(epochs * steps_per_epoch * 0.1)
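    # Use math.ceil so a final partial batch still counts as an eval step and
    # the full evaluation set is covered.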
    eval_steps = int(
        math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy='mirrored', num_gpus=self.num_gpus)

    run_classifier.run_customized_training(
        strategy,
        bert_config,
        input_meta_data,
        FLAGS.model_dir,
        epochs,
        steps_per_epoch,
        eval_steps,
        warmup_steps,
        FLAGS.learning_rate,
        FLAGS.init_checkpoint,
        custom_callbacks=callbacks)


class BertClassifyBenchmarkReal(BertBenchmarkBase):
  """Short benchmark performance tests for BERT model.

  Tests BERT classification performance in different GPU configurations.
  The naming convention of below test cases follow
  `benchmark_(number of gpus)_gpu_(dataset type)` format.
  """

  def __init__(self, output_dir=None, **kwargs):
    self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
    self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
    self.bert_config_file = MODEL_CONFIG_FILE_PATH
    self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
    # Since we only care about performance metrics, we limit
    # the number of training steps and epochs to prevent unnecessarily
    # long tests.
    self.num_steps_per_epoch = 110
    self.num_epochs = 1

    super(BertClassifyBenchmarkReal, self).__init__(output_dir=output_dir)

  def _run_and_report_benchmark(self,
                                training_summary_path,
                                min_accuracy=0,
                                max_accuracy=1):
    """Starts BERT performance benchmark test."""

    start_time_sec = time.time()
    self._run_bert_classifier(callbacks=[self.timer_callback])
    wall_time_sec = time.time() - start_time_sec

    with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
      summary = json.loads(reader.read().decode('utf-8'))

    # Since we do not load from any pretrained checkpoints, we ignore all
    # accuracy metrics.
    summary.pop('eval_metrics', None)
    super(BertClassifyBenchmarkReal, self)._report_benchmark(
        stats=summary,
        wall_time_sec=wall_time_sec,
        min_accuracy=min_accuracy,
        max_accuracy=max_accuracy)

  def benchmark_1_gpu_mrpc(self):
    """Test BERT model performance with 1 GPU."""

    self._setup()
    self.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 4
    FLAGS.eval_batch_size = 4

    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_2_gpu_mrpc(self):
    """Test BERT model performance with 2 GPUs."""

    self._setup()
    self.num_gpus = 2
    FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 8
    FLAGS.eval_batch_size = 8

    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_4_gpu_mrpc(self):
    """Test BERT model performance with 4 GPUs."""

    self._setup()
    self.num_gpus = 4
    FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 16

    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_8_gpu_mrpc(self):
    """Test BERT model performance with 8 GPUs."""

    self._setup()
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file

    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path)


class BertClassifyAccuracy(BertBenchmarkBase):
  """Short accuracy test for BERT model.

  Tests BERT classification task model accuracy. The naming
  convention of below test cases follow
  `benchmark_(number of gpus)_gpu_(dataset type)` format.
  """

  def __init__(self, output_dir=None, **kwargs):
    self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
    self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
    self.bert_config_file = MODEL_CONFIG_FILE_PATH
    self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
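    # Unlike the performance benchmarks above, accuracy tests fine-tune from
    # a pretrained checkpoint so the reported accuracy is meaningful.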
    self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH

    super(BertClassifyAccuracy, self).__init__(output_dir=output_dir)

  def _run_and_report_benchmark(self,
                                training_summary_path,
                                min_accuracy=0.84,
                                max_accuracy=0.88):
    """Starts BERT accuracy benchmark test."""

    start_time_sec = time.time()
    self._run_bert_classifier(callbacks=[self.timer_callback])
    wall_time_sec = time.time() - start_time_sec

    with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
      summary = json.loads(reader.read().decode('utf-8'))

    super(BertClassifyAccuracy, self)._report_benchmark(
        stats=summary,
        wall_time_sec=wall_time_sec,
        min_accuracy=min_accuracy,
        max_accuracy=max_accuracy)

  def benchmark_8_gpu_mrpc(self):
    """Run BERT model accuracy test with 8 GPUs.

    Due to the comparatively small size of the MRPC dataset, the training
    accuracy metric has high variance between runs. We therefore allow a
    wide accuracy range (84% to 88%).
    """

    self._setup()
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.init_checkpoint = self.pretrained_checkpoint_path

    summary_path = os.path.join(FLAGS.model_dir, 'training_summary.txt')
    self._run_and_report_benchmark(summary_path)


if __name__ == '__main__':
  tf.test.main()