# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow_hub as hub

from official.modeling import tf_utils
from official.nlp import bert_modeling
from official.nlp.modeling import losses
from official.nlp.modeling import networks
from official.nlp.modeling.networks import bert_classifier
from official.nlp.modeling.networks import bert_pretrainer
from official.nlp.modeling.networks import bert_span_labeler


def gather_indexes(sequence_tensor, positions):
  """Gathers the vectors at the specific positions.

  Args:
      sequence_tensor: Sequence output of `BertModel` layer of shape
        (`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
        hidden units of `BertModel` layer.
      positions: Positions ids of tokens in sequence to mask for pretraining of
        with dimension (batch_size, max_predictions_per_seq) where
        `max_predictions_per_seq` is maximum number of tokens to mask out and
        predict per each sequence.

  Returns:
      Masked out sequence tensor of shape (batch_size * max_predictions_per_seq,
      num_hidden).
  """
  sequence_shape = tf_utils.get_shape_list(
      sequence_tensor, name='sequence_output_tensor')
  batch_size = sequence_shape[0]
  seq_length = sequence_shape[1]
  width = sequence_shape[2]

  flat_offsets = tf.keras.backend.reshape(
      tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
  flat_positions = tf.keras.backend.reshape(positions + flat_offsets, [-1])
  flat_sequence_tensor = tf.keras.backend.reshape(
      sequence_tensor, [batch_size * seq_length, width])
  output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

  return output_tensor
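
# Example usage of `gather_indexes` (a minimal sketch with random tensors;
# the shapes below are arbitrary):
#
#   sequence_output = tf.random.uniform([2, 8, 16])       # (batch, seq, hidden)
#   positions = tf.constant([[0, 3], [1, 5]], tf.int32)   # 2 predictions/seq
#   gathered = gather_indexes(sequence_output, positions)
#   # gathered.shape == (4, 16), i.e. (batch * max_predictions, hidden)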


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining."""

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics."""
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        sentence_labels, sentence_output)
    self.add_metric(
        next_sentence_accuracy,
        name='next_sentence_accuracy',
        aggregation='mean')

    self.add_metric(
        next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
           sentence_labels):
    """Implements call() for the layer."""
    lm_label_weights = tf.keras.backend.cast(lm_label_weights, tf.float32)

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=sentence_labels, predictions=sentence_output)
    loss = mask_label_loss + sentence_loss
    batch_shape = tf.slice(tf.keras.backend.shape(sentence_labels), [0], [1])
    # TODO(hongkuny): Avoid this hack and switch to add_loss.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output, sentence_labels,
                      sentence_loss)
    return final_loss
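
# Example usage of the loss layer (a minimal sketch; the tensors below are
# random stand-ins for real model outputs and labels, with an arbitrary
# vocab_size of 32):
#
#   loss_layer = BertPretrainLossAndMetricLayer(vocab_size=32)
#   lm_output = tf.random.uniform([2, 4, 32])      # masked-LM predictions
#   sentence_output = tf.random.uniform([2, 2])    # next-sentence predictions
#   lm_ids = tf.zeros([2, 4], tf.int32)
#   lm_weights = tf.ones([2, 4], tf.int32)
#   ns_labels = tf.zeros([2, 1], tf.int32)
#   per_example_loss = loss_layer(lm_output, sentence_output, lm_ids,
#                                 lm_weights, ns_labels)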


def get_transformer_encoder(bert_config,
                            sequence_length,
                            float_dtype=tf.float32):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A `bert_modeling.BertConfig` or `bert_modeling.AlbertConfig`
      object.
    sequence_length: Maximum sequence length of the training data.
    float_dtype: tf.dtype, tf.float32 or tf.float16.

  Returns:
    A networks.TransformerEncoder object (or a networks.AlbertTransformerEncoder
    when `bert_config` is an AlbertConfig).
  """
  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=sequence_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      float_dtype=float_dtype.name)
  if isinstance(bert_config, bert_modeling.AlbertConfig):
    kwargs['embedding_width'] = bert_config.embedding_size
    return networks.AlbertTransformerEncoder(**kwargs)
  else:
    assert isinstance(bert_config, bert_modeling.BertConfig)
    return networks.TransformerEncoder(**kwargs)
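
# Example usage of `get_transformer_encoder` (a minimal sketch; the config
# values are arbitrary small numbers, not a released BERT configuration):
#
#   config = bert_modeling.BertConfig(
#       vocab_size=1000, hidden_size=64, num_hidden_layers=2,
#       num_attention_heads=2, intermediate_size=256)
#   encoder = get_transformer_encoder(config, sequence_length=128)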


def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None):
  """Returns model to be used for pre-training.

  Args:
      bert_config: Configuration that defines the core BERT model.
      seq_length: Maximum sequence length of the training data.
      max_predictions_per_seq: Maximum number of tokens in sequence to mask out
        and use for pretraining.
      initializer: Initializer for weights in BertPretrainer.

  Returns:
      Pretraining model as well as core BERT submodel from which to save
      weights after pretraining.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  masked_lm_weights = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)
  next_sentence_labels = tf.keras.layers.Input(
      shape=(1,), name='next_sentence_labels', dtype=tf.int32)

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = bert_pretrainer.BertPretrainer(
      network=transformer_encoder,
      num_classes=2,  # The next sentence prediction label has two classes.
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='predictions')

  lm_output, sentence_output = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])

  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)
  keras_model = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
          'masked_lm_positions': masked_lm_positions,
          'masked_lm_ids': masked_lm_ids,
          'masked_lm_weights': masked_lm_weights,
          'next_sentence_labels': next_sentence_labels,
      },
      outputs=output_loss)
  return keras_model, transformer_encoder
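
# Example usage of `pretrain_model` (a minimal sketch, reusing the
# hypothetical small config from the sketch above):
#
#   config = bert_modeling.BertConfig(
#       vocab_size=1000, hidden_size=64, num_hidden_layers=2,
#       num_attention_heads=2, intermediate_size=256)
#   model, encoder = pretrain_model(
#       config, seq_length=128, max_predictions_per_seq=20)
#   # `model` is compiled and fit on pretraining data; `encoder` carries the
#   # weights to export for downstream tasks.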


class BertSquadLogitsLayer(tf.keras.layers.Layer):
  """Returns a layer that computes custom logits for BERT squad model."""

  def __init__(self, initializer=None, float_type=tf.float32, **kwargs):
    super(BertSquadLogitsLayer, self).__init__(**kwargs)
    self.initializer = initializer
    self.float_type = float_type

  def build(self, unused_input_shapes):
    """Implements build() for the layer."""
    self.final_dense = tf.keras.layers.Dense(
        units=2, kernel_initializer=self.initializer, name='final_dense')
    super(BertSquadLogitsLayer, self).build(unused_input_shapes)

  def call(self, inputs):
    """Implements call() for the layer."""
    sequence_output = inputs

    input_shape = tf_utils.get_shape_list(
        sequence_output, name='sequence_output_tensor')
    sequence_length = input_shape[1]
    num_hidden_units = input_shape[2]

    final_hidden_input = tf.keras.backend.reshape(sequence_output,
                                                  [-1, num_hidden_units])
    logits = self.final_dense(final_hidden_input)
    logits = tf.keras.backend.reshape(logits, [-1, sequence_length, 2])
    logits = tf.transpose(logits, [2, 0, 1])
    unstacked_logits = tf.unstack(logits, axis=0)
    if self.float_type == tf.float16:
      unstacked_logits = tf.cast(unstacked_logits, tf.float32)
    return unstacked_logits[0], unstacked_logits[1]
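
# Example usage of `BertSquadLogitsLayer` (a minimal sketch with a random
# sequence tensor standing in for encoder output):
#
#   logits_layer = BertSquadLogitsLayer(
#       initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
#   sequence_output = tf.random.uniform([2, 16, 64])  # (batch, seq, hidden)
#   start_logits, end_logits = logits_layer(sequence_output)
#   # start_logits.shape == end_logits.shape == (2, 16)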


def squad_model(bert_config,
                max_seq_length,
                float_type,
                initializer=None,
                hub_module_url=None):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config defines the core Bert model.
    max_seq_length: integer, the maximum input sequence length.
    float_type: tf.dtype, tf.float32 or tf.bfloat16.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaults to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to a BERT module.

  Returns:
    A tuple of (1) a Keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length,
                                           float_type)
    return bert_span_labeler.BertSpanLabeler(
        network=bert_encoder, initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  core_model = hub.KerasLayer(hub_module_url, trainable=True)
  _, sequence_output = core_model(
      [input_word_ids, input_mask, input_type_ids])

  squad_logits_layer = BertSquadLogitsLayer(
      initializer=initializer, float_type=float_type, name='squad_logits')
  start_logits, end_logits = squad_logits_layer(sequence_output)

  squad = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
      },
      outputs=[start_logits, end_logits],
      name='squad_model')
  return squad, core_model
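
# Example usage of `squad_model` (a minimal sketch; the config values are
# hypothetical, and passing a real `hub_module_url` would exercise the TF-Hub
# branch instead):
#
#   config = bert_modeling.BertConfig(
#       vocab_size=1000, hidden_size=64, num_hidden_layers=2,
#       num_attention_heads=2, intermediate_size=256)
#   squad, encoder = squad_model(config, max_seq_length=384,
#                                float_type=tf.float32)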


def classifier_model(bert_config,
                     float_type,
                     num_labels,
                     max_seq_length,
                     final_layer_initializer=None,
                     hub_module_url=None):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig; the config that defines the core
      BERT or ALBERT model.
    float_type: dtype, tf.float32 or tf.bfloat16.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults
      to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to a BERT module.

  Returns:
    A tuple of (1) the combined prediction model, mapping
    (words, mask, type) -> (class predictions), and (2) the core BERT
    sub-model, mapping (words, mask, type) -> (bert_outputs).
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return bert_classifier.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  bert_model = hub.KerasLayer(hub_module_url, trainable=True)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)

  output = tf.keras.layers.Dense(
      num_labels,
      kernel_initializer=initializer,
      name='output',
      dtype=float_type)(
          output)
  return tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=output), bert_model
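
# Example usage of `classifier_model` (a minimal sketch; the config values are
# hypothetical small numbers, not a released BERT configuration):
#
#   config = bert_modeling.BertConfig(
#       vocab_size=1000, hidden_size=64, num_hidden_layers=2,
#       num_attention_heads=2, intermediate_size=256)
#   model, encoder = classifier_model(
#       config, tf.float32, num_labels=3, max_seq_length=128)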