# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes BERT benchmarks and accuracy tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import json
import math
import os
import time

# pylint: disable=g-bad-import-order
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.benchmark import bert_benchmark_utils as benchmark_utils
from official.nlp import bert_modeling as modeling
from official.nlp.bert import input_pipeline
from official.nlp.bert import run_classifier
from official.utils.misc import distribution_utils
from official.utils.testing import benchmark_wrappers

# pylint: disable=line-too-long
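# Pretrained checkpoint, MRPC classification data, and BERT config used by the
# benchmark and accuracy tests below.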
PRETRAINED_CHECKPOINT_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_model.ckpt'
CLASSIFIER_TRAIN_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_train.tf_record'
CLASSIFIER_EVAL_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_eval.tf_record'
CLASSIFIER_INPUT_META_DATA_PATH = 'gs://tf-perfzero-data/bert/classification/mrpc_meta_data'
MODEL_CONFIG_FILE_PATH = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16/bert_config.json'
# pylint: enable=line-too-long

TMP_DIR = os.getenv('TMPDIR')
FLAGS = flags.FLAGS


class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
  """Base class to hold methods common to test classes in the module."""

  def __init__(self, output_dir=None):
    super(BertClassifyBenchmarkBase, self).__init__(output_dir)
    self.num_epochs = None
    self.num_steps_per_epoch = None

  @flagsaver.flagsaver
  def _run_bert_classifier(self, callbacks=None, use_ds=True):
    """Starts BERT classification task."""
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
      input_meta_data = json.loads(reader.read().decode('utf-8'))

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    epochs = self.num_epochs if self.num_epochs else FLAGS.num_train_epochs
    if self.num_steps_per_epoch:
      steps_per_epoch = self.num_steps_per_epoch
    else:
      train_data_size = input_meta_data['train_data_size']
      steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
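    # Warm up the learning rate over the first 10% of training steps.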
    warmup_steps = int(epochs * steps_per_epoch * 0.1)
    eval_steps = int(
        math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy='mirrored' if use_ds else 'off',
        num_gpus=self.num_gpus)

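    # One step per inner training loop keeps per-step callback timing
    # fine-grained.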
    steps_per_loop = 1

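    # Build tf.data input functions over the train and eval TFRecord files.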
    max_seq_length = input_meta_data['max_seq_length']
    train_input_fn = functools.partial(
        input_pipeline.create_classifier_dataset,
        FLAGS.train_data_path,
        seq_length=max_seq_length,
        batch_size=FLAGS.train_batch_size)
    eval_input_fn = functools.partial(
        input_pipeline.create_classifier_dataset,
        FLAGS.eval_data_path,
        seq_length=max_seq_length,
        batch_size=FLAGS.eval_batch_size,
        is_training=False,
        drop_remainder=False)
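    # Train and evaluate the classifier under the selected strategy.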
    run_classifier.run_bert_classifier(
        strategy,
        bert_config,
        input_meta_data,
        FLAGS.model_dir,
        epochs,
        steps_per_epoch,
        steps_per_loop,
        eval_steps,
        warmup_steps,
        FLAGS.learning_rate,
        FLAGS.init_checkpoint,
        train_input_fn,
        eval_input_fn,
        custom_callbacks=callbacks)


class BertClassifyBenchmarkReal(BertClassifyBenchmarkBase):
  """Short benchmark performance tests for BERT model.

  Tests BERT classification performance in different GPU configurations.
  The naming convention of below test cases follow
  `benchmark_(number of gpus)_gpu_(dataset type)` format.
  """

  def __init__(self, output_dir=TMP_DIR, **kwargs):
    super(BertClassifyBenchmarkReal, self).__init__(output_dir=output_dir)

    self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
    self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
    self.bert_config_file = MODEL_CONFIG_FILE_PATH
    self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH

    # Since we only care about performance metrics, we limit
    # the number of training steps and epochs to prevent unnecessarily
    # long tests.
    self.num_steps_per_epoch = 110
    self.num_epochs = 1

  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                training_summary_path,
                                min_accuracy=0,
                                max_accuracy=1,
                                use_ds=True):
    """Starts BERT performance benchmark test."""
    start_time_sec = time.time()
    self._run_bert_classifier(callbacks=[self.timer_callback], use_ds=use_ds)
    wall_time_sec = time.time() - start_time_sec

    with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
      summary = json.loads(reader.read().decode('utf-8'))

    # Since we do not load from any pretrained checkpoints, we ignore all
    # accuracy metrics.
    summary.pop('eval_metrics', None)
    super(BertClassifyBenchmarkReal, self)._report_benchmark(
        stats=summary,
        wall_time_sec=wall_time_sec,
        min_accuracy=min_accuracy,
        max_accuracy=max_accuracy)

  def benchmark_1_gpu_mrpc(self):
    """Test BERT model performance with 1 GPU."""

    self._setup()
    self.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 4
    FLAGS.eval_batch_size = 4

    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_1_gpu_mrpc_xla(self):
    """Test BERT model performance with 1 GPU."""

    self._setup()
    self.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_xla')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 4
    FLAGS.eval_batch_size = 4
    FLAGS.enable_xla = True

    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_1_gpu_mrpc_no_dist_strat(self):
    """Test BERT model performance with 1 GPU, no distribution strategy."""

    self._setup()
    self.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_mrpc_no_dist_strat')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 4
    FLAGS.eval_batch_size = 4

    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(summary_path, use_ds=False)

  def benchmark_2_gpu_mrpc(self):
    """Test BERT model performance with 2 GPUs."""

    self._setup()
    self.num_gpus = 2
    FLAGS.model_dir = self._get_model_dir('benchmark_2_gpu_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 8
    FLAGS.eval_batch_size = 8

    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_4_gpu_mrpc(self):
    """Test BERT model performance with 4 GPUs."""

    self._setup()
    self.num_gpus = 4
    FLAGS.model_dir = self._get_model_dir('benchmark_4_gpu_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 16

    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_8_gpu_mrpc(self):
    """Test BERT model performance with 8 GPUs."""

    self._setup()
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file

    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_1_gpu_amp_mrpc_no_dist_strat(self):
    """Performance for 1 GPU no DS with automatic mixed precision."""
    self._setup()
    self.num_gpus = 1
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_1_gpu_amp_mrpc_no_dist_strat')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 4
    FLAGS.eval_batch_size = 4
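    # Enable automatic mixed precision via the fp16 graph rewrite.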
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'

    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(summary_path, use_ds=False)

  def benchmark_8_gpu_amp_mrpc(self):
    """Test BERT model performance with 8 GPUs with automatic mixed precision.
    """

    self._setup()
    self.num_gpus = 8
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp_mrpc')
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
    FLAGS.train_batch_size = 32
    FLAGS.eval_batch_size = 32
    FLAGS.dtype = 'fp16'
    FLAGS.fp16_implementation = 'graph_rewrite'

    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(summary_path, use_ds=False)


class BertClassifyAccuracy(BertClassifyBenchmarkBase):
  """Short accuracy test for BERT model.

  Tests BERT classification task model accuracy. The naming
  convention of below test cases follow
  `benchmark_(number of gpus)_gpu_(dataset type)` format.
  """

  def __init__(self, output_dir=TMP_DIR, **kwargs):
    self.train_data_path = CLASSIFIER_TRAIN_DATA_PATH
    self.eval_data_path = CLASSIFIER_EVAL_DATA_PATH
    self.bert_config_file = MODEL_CONFIG_FILE_PATH
    self.input_meta_data_path = CLASSIFIER_INPUT_META_DATA_PATH
    self.pretrained_checkpoint_path = PRETRAINED_CHECKPOINT_PATH

    super(BertClassifyAccuracy, self).__init__(output_dir=output_dir)

  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                training_summary_path,
                                min_accuracy=0.84,
                                max_accuracy=0.88):
    """Starts BERT accuracy benchmark test."""

    start_time_sec = time.time()
    self._run_bert_classifier(callbacks=[self.timer_callback])
    wall_time_sec = time.time() - start_time_sec

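    # Keep the eval metrics in the summary so accuracy can be checked against
    # the allowed range, unlike the performance benchmarks above.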
    with tf.io.gfile.GFile(training_summary_path, 'rb') as reader:
      summary = json.loads(reader.read().decode('utf-8'))

    super(BertClassifyAccuracy, self)._report_benchmark(
        stats=summary,
        wall_time_sec=wall_time_sec,
        min_accuracy=min_accuracy,
        max_accuracy=max_accuracy)

  def _setup(self):
    super(BertClassifyAccuracy, self)._setup()
    FLAGS.train_data_path = self.train_data_path
    FLAGS.eval_data_path = self.eval_data_path
    FLAGS.input_meta_data_path = self.input_meta_data_path
    FLAGS.bert_config_file = self.bert_config_file
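    # Accuracy tests fine-tune from the pretrained BERT-large checkpoint.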
    FLAGS.init_checkpoint = self.pretrained_checkpoint_path

  def benchmark_8_gpu_mrpc(self):
    """Run BERT model accuracy test with 8 GPUs.

    Due to the comparatively small size of the MRPC dataset, the training
    accuracy metric has high variance between runs, so we allow a wide
    accuracy range (84% to 88%).
    """
    self._setup()
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc')

    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(summary_path)

  def benchmark_8_gpu_mrpc_xla(self):
    """Run BERT model accuracy test with 8 GPUs with XLA."""
    self._setup()
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_mrpc_xla')
    FLAGS.enable_xla = True
    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    self._run_and_report_benchmark(summary_path)


if __name__ == '__main__':
  tf.test.main()