bert_models.py 13.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
22
import tensorflow_hub as hub
23

24
from official.modeling import tf_utils
Chen Chen's avatar
Chen Chen committed
25
from official.nlp import bert_modeling
Chen Chen's avatar
Chen Chen committed
26
from official.nlp.modeling import losses
Hongkun Yu's avatar
Hongkun Yu committed
27
28
from official.nlp.modeling import networks
from official.nlp.modeling.networks import bert_classifier
Chen Chen's avatar
Chen Chen committed
29
from official.nlp.modeling.networks import bert_pretrainer
30
from official.nlp.modeling.networks import bert_span_labeler
31
32
33
34
35


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Keras layer that computes the BERT pretraining loss and records metrics.

  The loss is the sum of the weighted masked-LM cross-entropy and the
  next-sentence cross-entropy. Each call also registers four mean-aggregated
  metrics: `masked_lm_accuracy`, `lm_example_loss`, `next_sentence_accuracy`
  and `next_sentence_loss`.
  """

  def __init__(self, vocab_size, **kwargs):
    """Initializes the layer.

    Args:
      vocab_size: Vocabulary size. It is only stored (and reported by
        `get_config()`); the loss computation in this layer does not read it.
      **kwargs: Forwarded unchanged to `tf.keras.layers.Layer.__init__`.
    """
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    # Public attribute kept for backward compatibility with existing readers.
    self.config = {
        'vocab_size': vocab_size,
    }

  def get_config(self):
    """Returns the layer config, including the `vocab_size` argument.

    Keras' default `get_config()` refuses layers whose `__init__` takes extra
    arguments, so the override is needed for a model containing this layer to
    be serializable.

    Returns:
      A dict with the base layer config plus a 'vocab_size' entry.
    """
    config = super(BertPretrainLossAndMetricLayer, self).get_config()
    config['vocab_size'] = self._vocab_size
    return config

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Registers the pretraining metrics on this layer.

    Args:
      lm_output: Masked-LM predictions.
      lm_labels: Masked-LM target ids.
      lm_label_weights: Per-prediction weights used to average the masked-LM
        accuracy.
      lm_example_loss: Masked-LM loss value to report.
      sentence_output: Next-sentence predictions.
      sentence_labels: Next-sentence target labels.
      next_sentence_loss: Next-sentence loss value to report.
    """
    per_token_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
    weighted_hits = tf.reduce_sum(per_token_accuracy * lm_label_weights)
    # The small constant keeps the division defined when all weights are zero.
    total_weight = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = weighted_hits / total_weight
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        sentence_labels, sentence_output)
    self.add_metric(
        next_sentence_accuracy,
        name='next_sentence_accuracy',
        aggregation='mean')

    self.add_metric(
        next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
           sentence_labels):
    """Computes the total pretraining loss and records metrics.

    Args:
      lm_output: Masked-LM predictions.
      sentence_output: Next-sentence predictions.
      lm_label_ids: Masked-LM target ids.
      lm_label_weights: Masked-LM weights; cast to float32 before use.
      sentence_labels: Next-sentence target labels.

    Returns:
      A 1-D tensor whose length is the first dimension of `sentence_labels`,
      with every element set to the total loss.
    """
    lm_label_weights = tf.keras.backend.cast(lm_label_weights, tf.float32)

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=sentence_labels, predictions=sentence_output)
    loss = mask_label_loss + sentence_loss

    # Shape [1] tensor holding the size of the first (batch) dimension.
    batch_shape = tf.slice(tf.keras.backend.shape(sentence_labels), [0], [1])
    # TODO(hongkuny): Avoids the hack and switches add_loss.
    # The loss is replicated once per example so the model output carries a
    # batch dimension.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output, sentence_labels,
                      sentence_loss)
    return final_loss


Chen Chen's avatar
Chen Chen committed
87
88
89
def get_transformer_encoder(bert_config,
                            sequence_length,
                            float_dtype=tf.float32):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: Maximum sequence length of the training data.
    float_dtype: Float dtype for the encoder, e.g. tf.float32 or tf.float16.
      Anything `tf.as_dtype` accepts is allowed (a `tf.DType`, a dtype name
      such as 'float16', or a numpy dtype).

  Returns:
    A networks.AlbertTransformerEncoder when `bert_config` is an AlbertConfig,
    otherwise a networks.TransformerEncoder.

  Raises:
    AssertionError: If `bert_config` is neither an AlbertConfig nor a
      BertConfig.
  """
  # Normalize first so string / numpy dtypes work; a tf.DType passes through
  # unchanged, which keeps the previous behavior for existing callers.
  float_dtype_name = tf.as_dtype(float_dtype).name

  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=sequence_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      float_dtype=float_dtype_name)

  # The ALBERT check must stay first: it is the only branch that adds
  # `embedding_width`.
  if isinstance(bert_config, bert_modeling.AlbertConfig):
    kwargs['embedding_width'] = bert_config.embedding_size
    return networks.AlbertTransformerEncoder(**kwargs)

  assert isinstance(bert_config, bert_modeling.BertConfig), (
      'bert_config must be a BertConfig or AlbertConfig, got %s' %
      type(bert_config).__name__)
  return networks.TransformerEncoder(**kwargs)
121
122


123
124
125
126
127
128
129
130
131
132
133
def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None):
  """Builds the Keras model used for BERT pre-training.

  The model takes a dict of seven int32 inputs (word ids, mask, type ids,
  masked-LM positions / ids / weights, and next-sentence labels) and outputs
  the per-example loss produced by `BertPretrainLossAndMetricLayer`.

  Args:
    bert_config: Configuration that defines the core BERT model.
    seq_length: Maximum sequence length of the training data.
    max_predictions_per_seq: Maximum number of tokens in sequence to mask out
      and use for pretraining.
    initializer: Initializer for weights in BertPretrainer. When None, a
      TruncatedNormal with stddev `bert_config.initializer_range` is used.

  Returns:
    A tuple of (1) the pretraining Keras model and (2) the core transformer
    encoder, from which weights can be saved after pretraining.
  """
  # (input name, width of the non-batch dimension), in creation order.
  feature_widths = (
      ('input_word_ids', seq_length),
      ('input_mask', seq_length),
      ('input_type_ids', seq_length),
      ('masked_lm_positions', max_predictions_per_seq),
      ('masked_lm_ids', max_predictions_per_seq),
      ('masked_lm_weights', max_predictions_per_seq),
      ('next_sentence_labels', 1),
  )
  features = {
      feature_name: tf.keras.layers.Input(
          shape=(width,), name=feature_name, dtype=tf.int32)
      for feature_name, width in feature_widths
  }

  encoder = get_transformer_encoder(bert_config, seq_length)

  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  pretrainer = bert_pretrainer.BertPretrainer(
      network=encoder,
      num_classes=2,  # The next sentence prediction label has two classes.
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='predictions')
  lm_output, sentence_output = pretrainer([
      features['input_word_ids'],
      features['input_mask'],
      features['input_type_ids'],
      features['masked_lm_positions'],
  ])

  loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = loss_layer(lm_output, sentence_output,
                           features['masked_lm_ids'],
                           features['masked_lm_weights'],
                           features['next_sentence_labels'])

  return tf.keras.Model(inputs=features, outputs=output_loss), encoder
189
190
191
192
193
194
195
196
197
198
199


class BertSquadLogitsLayer(tf.keras.layers.Layer):
  """Keras layer that maps a sequence output to SQuAD start and end logits."""

  def __init__(self, initializer=None, float_type=tf.float32, **kwargs):
    """Initializes the layer.

    Args:
      initializer: Kernel initializer for the final dense projection.
      float_type: Float dtype of the model; when it equals tf.float16 the
        logits are cast to tf.float32 before being returned.
      **kwargs: Forwarded unchanged to `tf.keras.layers.Layer.__init__`.
    """
    super(BertSquadLogitsLayer, self).__init__(**kwargs)
    self.float_type = float_type
    self.initializer = initializer

  def build(self, unused_input_shapes):
    """Creates the two-unit dense projection and marks the layer as built."""
    self.final_dense = tf.keras.layers.Dense(
        units=2, name='final_dense', kernel_initializer=self.initializer)
    super(BertSquadLogitsLayer, self).build(unused_input_shapes)

  def call(self, inputs):
    """Projects every position to two logits and splits them.

    Args:
      inputs: Sequence output tensor. Index 1 of its shape is used as the
        sequence length and index 2 as the hidden size.

    Returns:
      A `(logits_0, logits_1)` pair taken from the two units of the dense
      projection, each reshaped to `[-1, sequence_length]`. The caller in this
      file treats them as start and end logits.
    """
    dims = tf_utils.get_shape_list(inputs, name='sequence_output_tensor')
    seq_len, hidden_size = dims[1], dims[2]

    # Flatten all positions so a single dense layer projects each of them.
    flat_positions = tf.keras.backend.reshape(inputs, [-1, hidden_size])
    position_logits = tf.keras.backend.reshape(
        self.final_dense(flat_positions), [-1, seq_len, 2])

    # Move the size-2 axis to the front, then split along it.
    split_logits = tf.unstack(
        tf.transpose(position_logits, [2, 0, 1]), axis=0)
    if self.float_type == tf.float16:
      split_logits = tf.cast(split_logits, tf.float32)

    first_logits, second_logits = split_logits[0], split_logits[1]
    return first_logits, second_logits


Hongkun Yu's avatar
Hongkun Yu committed
225
226
227
228
def squad_model(bert_config,
                max_seq_length,
                float_type,
                initializer=None,
                hub_module_url=None):
  """Builds the BERT SQuAD model and returns it with its core BERT model.

  With a `hub_module_url`, the core model is loaded from TF-Hub and a
  `BertSquadLogitsLayer` is stacked on its sequence output. Otherwise a
  transformer encoder is built from `bert_config` and wrapped in a
  `BertSpanLabeler`.

  Args:
    bert_config: BertConfig, the config defines the core Bert model.
    max_seq_length: integer, the maximum input sequence length.
    float_type: tf.dtype handed to the encoder (non-hub path) or to the
      logits layer (hub path).
    initializer: Initializer for the final dense layer in the span labeler.
      Defaulted to TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module.

  Returns:
    A tuple of (1) keras model that outputs start logits and end logits and
    (2) the core BERT model: the transformer encoder, or the hub layer when
    `hub_module_url` is given.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if hub_module_url:
    model_inputs = {
        'input_word_ids': tf.keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids'),
        'input_mask': tf.keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name='input_mask'),
        'input_type_ids': tf.keras.layers.Input(
            shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids'),
    }
    core_model = hub.KerasLayer(hub_module_url, trainable=True)
    # Only the second output of the hub layer (the sequence output) is used.
    _, sequence_output = core_model([
        model_inputs['input_word_ids'],
        model_inputs['input_mask'],
        model_inputs['input_type_ids'],
    ])

    logits_layer = BertSquadLogitsLayer(
        initializer=initializer, float_type=float_type, name='squad_logits')
    start_logits, end_logits = logits_layer(sequence_output)

    squad = tf.keras.Model(
        inputs=model_inputs,
        outputs=[start_logits, end_logits],
        name='squad_model')
    return squad, core_model

  encoder = get_transformer_encoder(bert_config, max_seq_length, float_type)
  span_labeler = bert_span_labeler.BertSpanLabeler(
      network=encoder, initializer=initializer)
  return span_labeler, encoder


def classifier_model(bert_config,
                     float_type,
                     num_labels,
                     max_seq_length,
                     final_layer_initializer=None,
                     hub_module_url=None):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  With a `hub_module_url`, the core model is loaded from TF-Hub and its pooled
  output is fed through dropout and a dense layer of width `num_labels`.
  Otherwise a transformer encoder is built from `bert_config` and wrapped in a
  `BertClassifier`.

  NOTE(review): `float_type` is only used as the dtype of the final dense layer
  in the hub path; the non-hub path does not forward it to the encoder. Confirm
  that this is intended.

  Args:
    bert_config: BertConfig or AlbertConfig, the config defines the core
      BERT or ALBERT model.
    float_type: dtype, tf.float32 or tf.bfloat16.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for final dense layer. Defaulted
      TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module.

  Returns:
    A tuple of (1) the combined prediction model taking (words, mask, type)
    inputs and (2) the BERT sub-model: the transformer encoder, or the hub
    layer when `hub_module_url` is given.
  """
  initializer = final_layer_initializer
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if hub_module_url:
    word_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
    mask = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
    type_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')

    bert_model = hub.KerasLayer(hub_module_url, trainable=True)
    # Only the first output of the hub layer (the pooled output) is used.
    pooled_output, _ = bert_model([word_ids, mask, type_ids])

    dropped = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
        pooled_output)
    predictions = tf.keras.layers.Dense(
        num_labels,
        kernel_initializer=initializer,
        name='output',
        dtype=float_type)(
            dropped)

    classifier = tf.keras.Model(
        inputs={
            'input_word_ids': word_ids,
            'input_mask': mask,
            'input_type_ids': type_ids
        },
        outputs=predictions)
    return classifier, bert_model

  encoder = get_transformer_encoder(bert_config, max_seq_length)
  classifier = bert_classifier.BertClassifier(
      encoder,
      num_classes=num_labels,
      dropout_rate=bert_config.hidden_dropout_prob,
      initializer=initializer)
  return classifier, encoder