bert_models.py 11.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
22
import tensorflow_hub as hub
23

24
from official.modeling import tf_utils
25
from official.nlp.albert import configs as albert_configs
26
from official.nlp.bert import configs
Chen Chen's avatar
Chen Chen committed
27
from official.nlp.modeling import losses
Hongkun Yu's avatar
Hongkun Yu committed
28
29
from official.nlp.modeling import networks
from official.nlp.modeling.networks import bert_classifier
Chen Chen's avatar
Chen Chen committed
30
from official.nlp.modeling.networks import bert_pretrainer
31
from official.nlp.modeling.networks import bert_span_labeler
32
33
34
35
36


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining."""

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def get_config(self):
    """Returns the layer config so the layer can be serialized/cloned.

    `self.config` was already being populated in `__init__` but never
    surfaced; without this override, `tf.keras` cannot round-trip a model
    containing this layer.
    """
    config = super(BertPretrainLossAndMetricLayer, self).get_config()
    config.update(self.config)
    return config

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics.

    Args:
      lm_output: Masked LM predictions (float).
      lm_labels: Masked LM label ids.
      lm_label_weights: Per-position weights; positions with zero weight are
        padding and do not count toward accuracy.
      lm_example_loss: Scalar masked LM loss, reported as a metric.
      sentence_output: Next-sentence prediction logits/probabilities.
      sentence_labels: Next-sentence labels.
      next_sentence_loss: Scalar next-sentence loss, reported as a metric.
    """
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
    # Weighted mean over real (non-padding) masked positions only.
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    # Epsilon guards against division by zero when all weights are zero.
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        sentence_labels, sentence_output)
    self.add_metric(
        next_sentence_accuracy,
        name='next_sentence_accuracy',
        aggregation='mean')

    self.add_metric(
        next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
           sentence_labels):
    """Implements call() for the layer.

    Computes the combined masked-LM + next-sentence loss, registers the
    pretraining metrics, and returns the loss broadcast to batch shape.
    """
    # Inputs may arrive as int/bfloat16; losses are computed in float32.
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output = tf.cast(lm_output, tf.float32)
    sentence_output = tf.cast(sentence_output, tf.float32)

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=sentence_labels, predictions=sentence_output)
    loss = mask_label_loss + sentence_loss
    batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
    # TODO(hongkuny): Avoids the hack and switches add_loss.
    # The scalar loss is tiled to batch shape so Keras treats it as a
    # per-example model output rather than using add_loss.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output, sentence_labels,
                      sentence_loss)
    return final_loss


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
90
def get_transformer_encoder(bert_config, sequence_length):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: Maximum sequence length of the training data.

  Returns:
    A networks.TransformerEncoder object.

  Raises:
    ValueError: If `bert_config` is neither a `configs.BertConfig` nor an
      `albert_configs.AlbertConfig`.
  """
  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=sequence_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  if isinstance(bert_config, albert_configs.AlbertConfig):
    # ALBERT factorizes the embedding; the encoder needs the (smaller)
    # embedding width in addition to the hidden size.
    kwargs['embedding_width'] = bert_config.embedding_size
    return networks.AlbertTransformerEncoder(**kwargs)
  if isinstance(bert_config, configs.BertConfig):
    return networks.TransformerEncoder(**kwargs)
  # Explicit raise instead of `assert`: asserts are stripped under
  # `python -O`, which would silently accept an unsupported config type.
  raise ValueError('Unsupported bert_config type: %s' % type(bert_config))
120
121


122
123
124
125
126
127
128
129
130
131
132
def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None):
  """Returns model to be used for pre-training.

  Args:
      bert_config: Configuration that defines the core BERT model.
      seq_length: Maximum sequence length of the training data.
      max_predictions_per_seq: Maximum number of tokens in sequence to mask out
        and use for pretraining.
      initializer: Initializer for weights in BertPretrainer.

  Returns:
      Pretraining model as well as core BERT submodel from which to save
      weights after pretraining.
  """

  def _int_input(shape, name):
    # All pretraining features (ids, masks, positions, weights, labels)
    # are fed as int32 tensors.
    return tf.keras.layers.Input(shape=shape, name=name, dtype=tf.int32)

  input_word_ids = _int_input((seq_length,), 'input_word_ids')
  input_mask = _int_input((seq_length,), 'input_mask')
  input_type_ids = _int_input((seq_length,), 'input_type_ids')
  masked_lm_positions = _int_input((max_predictions_per_seq,),
                                   'masked_lm_positions')
  masked_lm_ids = _int_input((max_predictions_per_seq,), 'masked_lm_ids')
  masked_lm_weights = _int_input((max_predictions_per_seq,),
                                 'masked_lm_weights')
  next_sentence_labels = _int_input((1,), 'next_sentence_labels')

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = bert_pretrainer.BertPretrainer(
      network=transformer_encoder,
      num_classes=2,  # The next sentence prediction label has two classes.
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='predictions')

  lm_output, sentence_output = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])

  # The loss layer emits the combined loss as the model "output" so that
  # compile/fit can consume it directly.
  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)

  keras_model = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
          'masked_lm_positions': masked_lm_positions,
          'masked_lm_ids': masked_lm_ids,
          'masked_lm_weights': masked_lm_weights,
          'next_sentence_labels': next_sentence_labels,
      },
      outputs=output_loss)
  return keras_model, transformer_encoder
188
189


Hongkun Yu's avatar
Hongkun Yu committed
190
191
192
def squad_model(bert_config,
                max_seq_length,
                initializer=None,
                hub_module_url=None,
                hub_module_trainable=True):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config defines the core Bert model.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaulted to TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    A tuple of (1) keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    # Native path: build the encoder from the config and wrap it with the
    # span labeler head.
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    span_labeler = bert_span_labeler.BertSpanLabeler(
        network=bert_encoder, initializer=initializer)
    return span_labeler, bert_encoder

  # Hub path: wrap the hub module in a Keras model that exposes the same
  # [sequence_output, pooled_output] interface as the native encoder.
  feature_inputs = {
      name: tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name=name)
      for name in ('input_word_ids', 'input_mask', 'input_type_ids')
  }
  hub_layer = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  # Note the hub module returns (pooled, sequence); the wrapper model
  # re-orders them to (sequence, pooled).
  pooled_output, sequence_output = hub_layer([
      feature_inputs['input_word_ids'],
      feature_inputs['input_mask'],
      feature_inputs['input_type_ids'],
  ])
  bert_encoder = tf.keras.Model(
      inputs=feature_inputs,
      outputs=[sequence_output, pooled_output],
      name='core_model')
  span_labeler = bert_span_labeler.BertSpanLabeler(
      network=bert_encoder, initializer=initializer)
  return span_labeler, bert_encoder
236
237
238
239
240


def classifier_model(bert_config,
                     num_labels,
                     max_seq_length,
                     final_layer_initializer=None,
                     hub_module_url=None,
                     hub_module_trainable=True):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
      ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for final dense layer. Defaulted
      TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    Combined prediction model (words, mask, type) -> (one-hot labels)
    BERT sub-model (words, mask, type) -> (bert_outputs)
  """
  if final_layer_initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  else:
    initializer = final_layer_initializer

  if not hub_module_url:
    # Native path: build the encoder from the config and attach the
    # classification head.
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    classifier = bert_classifier.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer)
    return classifier, bert_encoder

  # Hub path: classify from the hub module's pooled output.
  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  bert_model = hub.KerasLayer(
      hub_module_url, trainable=hub_module_trainable)
  # Only the pooled ([CLS]) output feeds the classifier head; the
  # per-token sequence output is discarded.
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  dropped = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)
  logits = tf.keras.layers.Dense(
      num_labels, kernel_initializer=initializer, name='output')(
          dropped)
  classifier = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=logits)
  return classifier, bert_model