bert_pretrain_benchmark.py 6.21 KB
Newer Older
Chen Chen's avatar
Chen Chen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes benchmark testing for bert pretraining."""
# pylint: disable=line-too-long
from __future__ import print_function

import json
import os
import time
from typing import Optional

from absl import flags
from absl import logging
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.benchmark import benchmark_wrappers
from official.benchmark import bert_benchmark_utils
Jing Li's avatar
Jing Li committed
31
from official.benchmark import owner_utils
Chen Chen's avatar
Chen Chen committed
32
33
34
35
36
from official.nlp.bert import run_pretraining
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils

# Expected range for the pretraining masked language modeling (MLM) accuracy.
# Reported as the min/max bounds of the 'masked_lm_accuracy' metric when an
# accuracy benchmark is run.
MIN_MLM_ACCURACY = 0.635
MAX_MLM_ACCURACY = 0.645

# Expected range for the pretraining next sentence prediction (NSP) accuracy.
# Reported as the min/max bounds of the 'next_sentence_accuracy' metric when an
# accuracy benchmark is run.
MIN_NSP_ACCURACY = 0.94
MAX_NSP_ACCURACY = 0.96

# Comma-separated file patterns for the sequence-length-128 pretraining data;
# assigned to FLAGS.input_files by the benchmarks.
BERT_PRETRAIN_FILES_SEQ128 = 'gs://mlcompass-data/bert/pretraining_data/seq_128/wikipedia.tfrecord*,gs://mlcompass-data/bert/pretraining_data/seq_128/books.tfrecord*'
# BERT-base model config; assigned to FLAGS.bert_config_file by the benchmarks.
BERT_BASE_CONFIG_FILE = 'gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12/bert_config.json'

FLAGS = flags.FLAGS


class BertPretrainAccuracyBenchmark(bert_benchmark_utils.BertBenchmarkBase):
  """Benchmark accuracy and performance tests for BERT Pretraining."""

  def __init__(self,
               output_dir: Optional[str] = None,
               tpu: Optional[str] = None,
               **kwargs):
    """Inits BertPretrainAccuracyBenchmark class.

    Args:
      output_dir: Directory where to output e.g. log files
      tpu: TPU name to use in a TPU benchmark.
      **kwargs: Additional keyword arguments.
    """
    super(BertPretrainAccuracyBenchmark, self).__init__(
        output_dir=output_dir, tpu=tpu, **kwargs)

  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self, summary_path: str, report_accuracy: bool):
    """Runs BERT pretraining on TPU and reports the resulting metrics.

    Args:
      summary_path: Path of the JSON training summary file that is read after
        pretraining finishes (presumably written by `run_bert_pretrain` under
        FLAGS.model_dir -- confirm).
      report_accuracy: If True, also report the MLM and NSP accuracy metrics
        together with their expected min/max bounds.
    """
    distribution = distribution_utils.get_distribution_strategy(
        distribution_strategy='tpu', tpu_address=self.tpu)
    logging.info('Flags: %s', flags_core.get_nondefault_flags_as_str())
    start_time_sec = time.time()
    run_pretraining.run_bert_pretrain(
        strategy=distribution, custom_callbacks=self.timer_callback)
    wall_time_sec = time.time() - start_time_sec

    # Read raw bytes and decode explicitly so parsing does not depend on the
    # platform's default text encoding.
    with tf.io.gfile.GFile(summary_path, 'rb') as reader:
      summary = json.loads(reader.read().decode('utf-8'))
    self._report_benchmark(summary, start_time_sec, wall_time_sec,
                           report_accuracy)

  def _report_benchmark(self, summary, start_time_sec, wall_time_sec,
                        report_accuracy):
    """Builds the metric list for a finished run and reports it.

    Args:
      summary: Dict parsed from the training summary JSON. Must contain
        'train_loss' and 'total_training_steps', plus 'masked_lm_accuracy' and
        'next_sentence_accuracy' when `report_accuracy` is True.
      start_time_sec: Wall-clock time (from `time.time()`) at which training
        started; used to compute the startup time.
      wall_time_sec: Total training wall time in seconds.
      report_accuracy: Whether to append the bounded accuracy metrics.
    """
    metrics = [{
        'name': 'train_loss',
        'value': summary['train_loss'],
    }, {
        'name':
            'exp_per_second',
        # The timer callback is given the number of examples processed in one
        # loop of `steps_per_loop` steps of `train_batch_size` each.
        'value':
            self.timer_callback.get_examples_per_sec(FLAGS.train_batch_size *
                                                     FLAGS.steps_per_loop)
    }, {
        'name': 'startup_time',
        'value': self.timer_callback.get_startup_time(start_time_sec)
    }]
    if report_accuracy:
      # Accuracy metrics carry bounds so the benchmark framework can flag a
      # run whose final accuracy falls outside the expected range.
      metrics.extend([{
          'name': 'masked_lm_accuracy',
          'value': summary['masked_lm_accuracy'],
          'min_value': MIN_MLM_ACCURACY,
          'max_value': MAX_MLM_ACCURACY,
      }, {
          'name': 'next_sentence_accuracy',
          'value': summary['next_sentence_accuracy'],
          'min_value': MIN_NSP_ACCURACY,
          'max_value': MAX_NSP_ACCURACY,
      }])
    self.report_benchmark(
        iters=summary['total_training_steps'],
        wall_time=wall_time_sec,
        metrics=metrics,
        extras={'flags': flags_core.get_nondefault_flags_as_str()})

  def _specify_common_flags(self):
    """Sets the flags shared by every benchmark in this class.

    BERT-base config, batch size 512, seq-128 data with 20 predictions per
    sequence, bf16 on a TPU distribution strategy.
    """
    FLAGS.bert_config_file = BERT_BASE_CONFIG_FILE
    FLAGS.train_batch_size = 512
    FLAGS.learning_rate = 1e-4
    FLAGS.warmup_steps = 10000
    FLAGS.steps_per_loop = 10000
    FLAGS.distribution_strategy = 'tpu'
    FLAGS.input_files = BERT_PRETRAIN_FILES_SEQ128
    FLAGS.max_seq_length = 128
    FLAGS.max_predictions_per_seq = 20
    FLAGS.dtype = 'bf16'

  @owner_utils.Owner('tf-model-garden')
  def benchmark_accuracy_8x8_tpu_bf16_seq128_500k_steps(self):
    """Test bert pretraining with 8x8 TPU for 500k steps."""
    # This is used for accuracy test.
    self._setup()
    self._specify_common_flags()
    FLAGS.num_steps_per_epoch = 500000
    FLAGS.num_train_epochs = 1
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_accuracy_8x8_tpu_bf16_seq128_500k_steps')
    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    # Set train_summary_interval to -1 to disable training summary, because
    # writing summary to gcs may fail and summaries are not needed for this
    # accuracy benchmark test.
    FLAGS.train_summary_interval = -1
    self._run_and_report_benchmark(summary_path=summary_path,
                                   report_accuracy=True)

  @owner_utils.Owner('tf-model-garden')
  def benchmark_perf_8x8_tpu_bf16_seq128_10k_steps(self):
    """Test bert pretraining with 8x8 TPU for 10000 steps."""
    self._setup()
    self._specify_common_flags()
    # 2 epochs x 5000 steps = 10k steps in total.
    FLAGS.num_steps_per_epoch = 5000
    FLAGS.num_train_epochs = 2
    FLAGS.model_dir = self._get_model_dir(
        'benchmark_perf_8x8_tpu_bf16_seq128_10k_steps')
    summary_path = os.path.join(FLAGS.model_dir,
                                'summaries/training_summary.txt')
    # Disable accuracy check.
    self._run_and_report_benchmark(summary_path=summary_path,
                                   report_accuracy=False)


if __name__ == '__main__':
  # When executed as a script, hand control to TensorFlow's test entry point.
  tf.test.main()