# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow_hub as hub

from official.modeling import tf_utils
from official.nlp.modeling import losses
from official.nlp.modeling import networks
from official.nlp.modeling.networks import bert_classifier
from official.nlp.modeling.networks import bert_pretrainer
from official.nlp.modeling.networks import bert_span_labeler


def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions.

  Args:
      sequence_tensor: Sequence output of `BertModel` layer of shape
        (`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
        hidden units of `BertModel` layer.
      positions: Position ids of the tokens to mask for pretraining, with
        dimension (batch_size, max_predictions_per_seq) where
        `max_predictions_per_seq` is the maximum number of tokens to mask out
        and predict per sequence.

  Returns:
      Masked out sequence tensor of shape (batch_size * max_predictions_per_seq,
      num_hidden).
  """
  sequence_shape = tf_utils.get_shape_list(
      sequence_tensor, name='sequence_output_tensor')
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]

  flat_offsets = tf.keras.backend.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.keras.backend.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.keras.backend.reshape(
      sequence_tensor, [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

  return output_tensor
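
# Example usage of `gather_indexes` (illustrative sketch; the shapes and values
# below are assumptions, not taken from a real pipeline):
#
#   sequence_output = tf.random.uniform([2, 8, 16])             # (batch, seq_len, hidden)
#   positions = tf.constant([[1, 3], [0, 2]], dtype=tf.int32)   # (batch, max_predictions)
#   gathered = gather_indexes(sequence_output, positions)       # shape (4, 16)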


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining."""

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def __call__(self,
               lm_output,
               sentence_output=None,
               lm_label_ids=None,
               lm_label_weights=None,
               sentence_labels=None,
               **kwargs):
    inputs = tf_utils.pack_inputs([
        lm_output, sentence_output, lm_label_ids, lm_label_weights,
        sentence_labels
    ])
    return super(BertPretrainLossAndMetricLayer,
                 self).__call__(inputs, **kwargs)

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics."""
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        sentence_labels, sentence_output)
    self.add_metric(
        next_sentence_accuracy,
        name='next_sentence_accuracy',
        aggregation='mean')

    self.add_metric(
        next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self, inputs):
    """Implements call() for the layer."""
    unpacked_inputs = tf_utils.unpack_inputs(inputs)
    lm_output = unpacked_inputs[0]
    sentence_output = unpacked_inputs[1]
    lm_label_ids = unpacked_inputs[2]
    lm_label_weights = tf.keras.backend.cast(unpacked_inputs[3], tf.float32)
    sentence_labels = unpacked_inputs[4]

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=sentence_labels, predictions=sentence_output)
    loss = mask_label_loss + sentence_loss
    batch_shape = tf.slice(tf.keras.backend.shape(sentence_labels), [0], [1])
    # TODO(hongkuny): Avoid this hack and switch to add_loss.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output, sentence_labels,
                      sentence_loss)
    return final_loss
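
# Example usage (illustrative sketch; the vocabulary size is the standard
# BERT-Base value, and the tensors mirror those built in `pretrain_model`
# below):
#
#   loss_layer = BertPretrainLossAndMetricLayer(vocab_size=30522)
#   final_loss = loss_layer(lm_output, sentence_output, masked_lm_ids,
#                           masked_lm_weights, next_sentence_labels)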


def _get_transformer_encoder(bert_config,
                             sequence_length,
                             float_dtype=tf.float32):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' object.
    sequence_length: Maximum sequence length of the training data.
    float_dtype: tf.dtype, tf.float32 or tf.float16.

  Returns:
    A networks.TransformerEncoder object.
  """
  return networks.TransformerEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation('gelu'),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=sequence_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      float_dtype=float_dtype.name)
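
# Example usage (illustrative sketch; `bert_config` is assumed to be a
# populated `BertConfig` object created elsewhere):
#
#   encoder = _get_transformer_encoder(bert_config, sequence_length=128)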


def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None):
  """Returns model to be used for pre-training.

  Args:
      bert_config: Configuration that defines the core BERT model.
      seq_length: Maximum sequence length of the training data.
      max_predictions_per_seq: Maximum number of tokens in sequence to mask out
        and use for pretraining.
      initializer: Initializer for weights in BertPretrainer.

  Returns:
      Pretraining model as well as core BERT submodel from which to save
      weights after pretraining.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  masked_lm_weights = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)
  next_sentence_labels = tf.keras.layers.Input(
      shape=(1,), name='next_sentence_labels', dtype=tf.int32)

  transformer_encoder = _get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = bert_pretrainer.BertPretrainer(
      network=transformer_encoder,
      num_classes=2,  # The next sentence prediction label has two classes.
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='predictions')

  lm_output, sentence_output = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])

  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)
  keras_model = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
          'masked_lm_positions': masked_lm_positions,
          'masked_lm_ids': masked_lm_ids,
          'masked_lm_weights': masked_lm_weights,
          'next_sentence_labels': next_sentence_labels,
      },
      outputs=output_loss)
  return keras_model, transformer_encoder
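
# Example usage (illustrative sketch; `bert_config` is assumed to be a
# populated `BertConfig` object, and the lengths below are arbitrary):
#
#   pretraining_model, encoder = pretrain_model(
#       bert_config, seq_length=128, max_predictions_per_seq=20)
#   # `pretraining_model` outputs the combined pretraining loss; `encoder`
#   # holds the core BERT weights for export after training.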


class BertSquadLogitsLayer(tf.keras.layers.Layer):
  """Returns a layer that computes custom logits for BERT squad model."""

  def __init__(self, initializer=None, float_type=tf.float32, **kwargs):
    super(BertSquadLogitsLayer, self).__init__(**kwargs)
    self.initializer = initializer
    self.float_type = float_type

  def build(self, unused_input_shapes):
    """Implements build() for the layer."""
    self.final_dense = tf.keras.layers.Dense(
        units=2, kernel_initializer=self.initializer, name='final_dense')
    super(BertSquadLogitsLayer, self).build(unused_input_shapes)

  def call(self, inputs):
    """Implements call() for the layer."""
    sequence_output = inputs

    input_shape = sequence_output.shape.as_list()
    sequence_length = input_shape[1]
    num_hidden_units = input_shape[2]

    final_hidden_input = tf.keras.backend.reshape(sequence_output,
                                                  [-1, num_hidden_units])
    logits = self.final_dense(final_hidden_input)
    logits = tf.keras.backend.reshape(logits, [-1, sequence_length, 2])
    logits = tf.transpose(logits, [2, 0, 1])
    unstacked_logits = tf.unstack(logits, axis=0)
    if self.float_type == tf.float16:
      unstacked_logits = tf.cast(unstacked_logits, tf.float32)
    return unstacked_logits[0], unstacked_logits[1]
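
# Example usage (illustrative sketch; the shapes below are assumptions):
#
#   squad_logits = BertSquadLogitsLayer(
#       initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
#   sequence_output = tf.random.uniform([2, 384, 768])        # (batch, seq_len, hidden)
#   start_logits, end_logits = squad_logits(sequence_output)  # each of shape (2, 384)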


def squad_model(bert_config,
                max_seq_length,
                float_type,
                initializer=None,
                hub_module_url=None):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config defines the core Bert model.
    max_seq_length: integer, the maximum input sequence length.
    float_type: tf.dtype, tf.float32 or tf.bfloat16.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaults to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module.

  Returns:
    A tuple of (1) keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  if not hub_module_url:
    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length,
                                            float_type)
    return bert_span_labeler.BertSpanLabeler(
        network=bert_encoder, initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  core_model = hub.KerasLayer(hub_module_url, trainable=True)
  _, sequence_output = core_model(
      [input_word_ids, input_mask, input_type_ids])
  # Sets the shape manually due to a bug in TF shape inference.
  # TODO(hongkuny): remove this once shape inference is correct.
  sequence_output.set_shape((None, max_seq_length, bert_config.hidden_size))

  squad_logits_layer = BertSquadLogitsLayer(
      initializer=initializer, float_type=float_type, name='squad_logits')
  start_logits, end_logits = squad_logits_layer(sequence_output)

  squad = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
      },
      outputs=[start_logits, end_logits],
      name='squad_model')
  return squad, core_model
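
# Example usage (illustrative sketch; `bert_config` is assumed to be a
# populated `BertConfig` object and the sequence length is arbitrary):
#
#   squad, core = squad_model(bert_config, max_seq_length=384,
#                             float_type=tf.float32)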


def classifier_model(bert_config,
                     float_type,
                     num_labels,
                     max_seq_length,
                     final_layer_initializer=None,
                     hub_module_url=None):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig, the config defines the core BERT model.
    float_type: dtype, tf.float32 or tf.bfloat16.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults to
      a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module.

  Returns:
    Combined prediction model (words, mask, type) -> (one-hot labels) and the
    BERT sub-model (words, mask, type) -> (bert_outputs).
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
    return bert_classifier.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  bert_model = hub.KerasLayer(hub_module_url, trainable=True)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)

  output = tf.keras.layers.Dense(
      num_labels,
      kernel_initializer=initializer,
      name='output',
      dtype=float_type)(
          output)
  return tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=output), bert_model
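
# Example usage (illustrative sketch; `bert_config` is assumed to be a
# populated `BertConfig` object and the label count is arbitrary):
#
#   classifier, encoder = classifier_model(
#       bert_config, tf.float32, num_labels=3, max_seq_length=128)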