# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT models that are compatible with TF 2.0."""

import gin
import tensorflow as tf
import tensorflow_hub as hub
from official.legacy.nlp.albert import configs as albert_configs
from official.modeling import tf_utils
from official.nlp.bert import configs
from official.nlp.modeling import models
from official.nlp.modeling import networks


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining."""

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics."""
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
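    # Compute masked-LM accuracy only over real prediction slots: each
    # position is weighted by its label weight, and the small epsilon below
    # keeps the denominator nonzero when a batch has no valid predictions.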
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    if sentence_labels is not None:
      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
          sentence_labels, sentence_output)
      self.add_metric(
          next_sentence_accuracy,
          name='next_sentence_accuracy',
          aggregation='mean')

    if next_sentence_loss is not None:
      self.add_metric(
          next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self,
           lm_output_logits,
           sentence_output_logits,
           lm_label_ids,
           lm_label_weights,
           sentence_labels=None):
    """Implements call() for the layer."""
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output_logits = tf.cast(lm_output_logits, tf.float32)

    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
        lm_label_ids, lm_output_logits, from_logits=True)
    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
    mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
                                            lm_denominator_loss)

    if sentence_labels is not None:
      sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
          sentence_labels, sentence_output_logits, from_logits=True)
      sentence_loss = tf.reduce_mean(sentence_loss)
      loss = mask_label_loss + sentence_loss
    else:
      sentence_loss = None
      loss = mask_label_loss

    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
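    # Keras expects a per-example loss, so the scalar loss is broadcast to a
    # batch-sized tensor with tf.fill below.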
    # TODO(hongkuny): Avoid this hack and switch to add_loss.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output_logits, sentence_labels,
                      sentence_loss)
    return final_loss
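

# Illustrative sketch (an addition, not part of the original module): the
# weighted masked-LM loss computed in `call` above, on toy values. With
# weights [1, 1, 0], only the first two prediction slots contribute.
def _example_masked_lm_loss_math():
  lm_logits = tf.random.normal([1, 3, 5])  # [batch, num_predictions, vocab]
  lm_label_ids = tf.constant([[2, 4, 0]])
  lm_label_weights = tf.constant([[1., 1., 0.]])
  per_position_loss = tf.keras.losses.sparse_categorical_crossentropy(
      lm_label_ids, lm_logits, from_logits=True)
  return tf.math.divide_no_nan(
      tf.reduce_sum(per_position_loss * lm_label_weights),
      tf.reduce_sum(lm_label_weights))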


@gin.configurable
def get_transformer_encoder(bert_config,
                            sequence_length=None,
                            transformer_encoder_cls=None,
                            output_range=None):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: [Deprecated].
    transformer_encoder_cls: An EncoderScaffold class. If None, uses the
      default BERT encoder implementation.
    output_range: The sequence output range, [0, output_range). The default is
      to return the entire sequence output.

  Returns:
    An encoder object.
  """
  del sequence_length
  if transformer_encoder_cls is not None:
    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
    embedding_cfg = dict(
        vocab_size=bert_config.vocab_size,
        type_vocab_size=bert_config.type_vocab_size,
        hidden_size=bert_config.hidden_size,
        max_seq_length=bert_config.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
        dropout_rate=bert_config.hidden_dropout_prob,
    )
    hidden_cfg = dict(
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=bert_config.num_hidden_layers,
        pooled_output_dim=bert_config.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range))

    # Relies on gin configuration to define the Transformer encoder arguments.
    return transformer_encoder_cls(**kwargs)

  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      embedding_width=bert_config.embedding_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  if isinstance(bert_config, albert_configs.AlbertConfig):
    return networks.AlbertEncoder(**kwargs)
  else:
    assert isinstance(bert_config, configs.BertConfig)
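    # `output_range` is only forwarded to the BERT encoder; the ALBERT branch
    # above does not take it.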
    kwargs['output_range'] = output_range
    return networks.BertEncoder(**kwargs)
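

# Illustrative sketch (an addition, not part of the original module):
# `_toy_bert_config` is a hypothetical helper whose small, arbitrary values
# keep the examples below fast; real configs usually come from a checkpoint's
# JSON config file.
def _toy_bert_config():
  return configs.BertConfig(
      vocab_size=100,
      hidden_size=16,
      num_hidden_layers=2,
      num_attention_heads=2,
      intermediate_size=32)


def _example_build_encoder():
  # Builds the default BertEncoder; pass `transformer_encoder_cls` to swap in
  # an EncoderScaffold-based implementation instead.
  return get_transformer_encoder(_toy_bert_config())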


def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None,
                   use_next_sentence_label=True,
                   return_core_pretrainer_model=False):
  """Returns model to be used for pre-training.

  Args:
      bert_config: Configuration that defines the core BERT model.
      seq_length: Maximum sequence length of the training data.
      max_predictions_per_seq: Maximum number of tokens in sequence to mask out
        and use for pretraining.
      initializer: Initializer for weights in BertPretrainer.
      use_next_sentence_label: Whether to use the next sentence label.
      return_core_pretrainer_model: Whether to also return the `BertPretrainer`
        object.

  Returns:
      A tuple of (1) the pretraining model, (2) the core BERT submodel from
      which to save weights after pretraining, and (3) optionally the core
      `BertPretrainer` object if `return_core_pretrainer_model` is True.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  masked_lm_weights = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)

  if use_next_sentence_label:
    next_sentence_labels = tf.keras.layers.Input(
        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
  else:
    next_sentence_labels = None

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = models.BertPretrainer(
      network=transformer_encoder,
      embedding_table=transformer_encoder.get_embedding_table(),
      num_classes=2,  # The next sentence prediction label has two classes.
      activation=tf_utils.get_activation(bert_config.hidden_act),
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='logits')

  outputs = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
  lm_output = outputs['masked_lm']
  sentence_output = outputs['classification']
  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)
  inputs = {
      'input_word_ids': input_word_ids,
      'input_mask': input_mask,
      'input_type_ids': input_type_ids,
      'masked_lm_positions': masked_lm_positions,
      'masked_lm_ids': masked_lm_ids,
      'masked_lm_weights': masked_lm_weights,
  }
  if use_next_sentence_label:
    inputs['next_sentence_labels'] = next_sentence_labels

  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
  if return_core_pretrainer_model:
    return keras_model, transformer_encoder, pretrainer_model
  else:
    return keras_model, transformer_encoder
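

# Illustrative sketch (an addition, not part of the original module): wiring
# the pretraining model and loss with the hypothetical `_toy_bert_config`
# helper defined above.
def _example_pretrain_model():
  keras_model, encoder = pretrain_model(
      _toy_bert_config(), seq_length=32, max_predictions_per_seq=4)
  return keras_model, encoder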


def squad_model(bert_config,
                max_seq_length,
                initializer=None,
                hub_module_url=None,
                hub_module_trainable=True):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config that defines the core BERT model.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaults to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/URL to the BERT module.
    hub_module_trainable: True to fine-tune layers in the hub module.

  Returns:
    A tuple of (1) a Keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return models.BertSpanLabeler(
        network=bert_encoder, initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, sequence_output = core_model(
      [input_word_ids, input_mask, input_type_ids])
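  # Wrap the hub layer in a Keras model that exposes [sequence_output,
  # pooled_output], so the span labeler receives the sequence output first.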
  bert_encoder = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
      },
      outputs=[sequence_output, pooled_output],
      name='core_model')
  return models.BertSpanLabeler(
      network=bert_encoder, initializer=initializer), bert_encoder
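

# Illustrative sketch (an addition, not part of the original module): a span
# labeler built without a hub module, so the encoder weights are freshly
# initialized; reuses the hypothetical `_toy_bert_config` helper above.
def _example_squad_model():
  span_model, encoder = squad_model(_toy_bert_config(), max_seq_length=32)
  return span_model, encoder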


def classifier_model(bert_config,
                     num_labels,
                     max_seq_length=None,
                     final_layer_initializer=None,
                     hub_module_url=None,
                     hub_module_trainable=True):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig, the config that defines the core
      BERT or ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults
      to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/URL to the BERT module.
    hub_module_trainable: True to fine-tune layers in the hub module.

  Returns:
    Combined prediction model (words, mask, type) -> (one-hot labels)
    BERT sub-model (words, mask, type) -> (bert_outputs)
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    bert_encoder = get_transformer_encoder(
        bert_config, max_seq_length, output_range=1)
    return models.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)

  output = tf.keras.layers.Dense(
      num_labels, kernel_initializer=initializer, name='output')(
          output)
  return tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=output), bert_model
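

# Illustrative sketch (an addition, not part of the original module): a
# two-class classifier built without a hub module; reuses the hypothetical
# `_toy_bert_config` helper above.
def _example_classifier_model():
  model, encoder = classifier_model(
      _toy_bert_config(), num_labels=2, max_seq_length=32)
  return model, encoder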