# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow_hub as hub

from official.modeling import tf_utils
from official.nlp import bert_modeling
from official.nlp.albert import configs as albert_configs
from official.nlp.modeling import losses
from official.nlp.modeling import networks
from official.nlp.modeling.networks import bert_classifier
from official.nlp.modeling.networks import bert_pretrainer
from official.nlp.modeling.networks import bert_span_labeler


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining."""

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics."""
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
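    # Average accuracy only over real masked positions: padded prediction
    # slots carry zero weight, so they drop out of both sums below.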
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        sentence_labels, sentence_output)
    self.add_metric(
        next_sentence_accuracy,
        name='next_sentence_accuracy',
        aggregation='mean')

    self.add_metric(
        next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
           sentence_labels):
    """Implements call() for the layer."""
    lm_label_weights = tf.keras.backend.cast(lm_label_weights, tf.float32)

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=sentence_labels, predictions=sentence_output)
    loss = mask_label_loss + sentence_loss
    batch_shape = tf.slice(tf.keras.backend.shape(sentence_labels), [0], [1])
    # TODO(hongkuny): Avoid this hack and switch to add_loss.
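    # tf.fill broadcasts the scalar loss to [batch_size] so that Keras can
    # treat it as a per-example model output.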
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output, sentence_labels,
                      sentence_loss)
    return final_loss


def get_transformer_encoder(bert_config,
                            sequence_length):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: Maximum sequence length of the training data.

  Returns:
    A networks.TransformerEncoder object.
  """
  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=sequence_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  if isinstance(bert_config, albert_configs.AlbertConfig):
    kwargs['embedding_width'] = bert_config.embedding_size
    return networks.AlbertTransformerEncoder(**kwargs)
  else:
    assert isinstance(bert_config, bert_modeling.BertConfig)
    return networks.TransformerEncoder(**kwargs)
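

# Illustrative usage (a sketch, not part of this module; the config path and
# the `from_json_file` helper are assumed from the BERT config class):
#
#   config = bert_modeling.BertConfig.from_json_file('bert_config.json')
#   encoder = get_transformer_encoder(config, sequence_length=128)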


def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None):
  """Returns model to be used for pre-training.

  Args:
      bert_config: Configuration that defines the core BERT model.
      seq_length: Maximum sequence length of the training data.
      max_predictions_per_seq: Maximum number of tokens in sequence to mask out
        and use for pretraining.
      initializer: Initializer for weights in BertPretrainer.

  Returns:
      Pretraining model as well as core BERT submodel from which to save
      weights after pretraining.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  masked_lm_weights = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)
  next_sentence_labels = tf.keras.layers.Input(
      shape=(1,), name='next_sentence_labels', dtype=tf.int32)

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = bert_pretrainer.BertPretrainer(
      network=transformer_encoder,
      num_classes=2,  # The next sentence prediction label has two classes.
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='predictions')

  lm_output, sentence_output = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])

  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)
  keras_model = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
          'masked_lm_positions': masked_lm_positions,
          'masked_lm_ids': masked_lm_ids,
          'masked_lm_weights': masked_lm_weights,
          'next_sentence_labels': next_sentence_labels,
      },
      outputs=output_loss)
  return keras_model, transformer_encoder
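

# Illustrative usage (a sketch; the config path, shapes, and optimizer are
# assumptions, not part of this module). The returned Keras model emits the
# total pretraining loss as its output, so a pass-through loss works:
#
#   config = bert_modeling.BertConfig.from_json_file('bert_config.json')
#   model, encoder = pretrain_model(
#       config, seq_length=128, max_predictions_per_seq=20)
#   model.compile(
#       optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
#       loss=lambda unused_labels, model_loss: tf.reduce_mean(model_loss))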


class BertSquadLogitsLayer(tf.keras.layers.Layer):
  """Returns a layer that computes custom logits for BERT squad model."""

  def __init__(self, initializer=None, **kwargs):
    super(BertSquadLogitsLayer, self).__init__(**kwargs)
    self.initializer = initializer

  def build(self, unused_input_shapes):
    """Implements build() for the layer."""
    self.final_dense = tf.keras.layers.Dense(
        units=2, kernel_initializer=self.initializer, name='final_dense')
    super(BertSquadLogitsLayer, self).build(unused_input_shapes)

  def call(self, inputs):
    """Implements call() for the layer."""
    sequence_output = inputs

    input_shape = tf_utils.get_shape_list(
        sequence_output, name='sequence_output_tensor')
    sequence_length = input_shape[1]
    num_hidden_units = input_shape[2]

    final_hidden_input = tf.keras.backend.reshape(sequence_output,
                                                  [-1, num_hidden_units])
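    # The dense layer emits two scores per token position: one for the span
    # start and one for the span end.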
    logits = self.final_dense(final_hidden_input)
    logits = tf.keras.backend.reshape(logits, [-1, sequence_length, 2])
    logits = tf.transpose(logits, [2, 0, 1])
    unstacked_logits = tf.unstack(logits, axis=0)
    return unstacked_logits[0], unstacked_logits[1]


def squad_model(bert_config,
                max_seq_length,
                initializer=None,
                hub_module_url=None):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config defines the core BERT model.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaults to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/URL to the BERT module.

  Returns:
    A tuple of (1) a Keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return bert_span_labeler.BertSpanLabeler(
        network=bert_encoder, initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
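
  # The TF-Hub BERT layer returns (pooled_output, sequence_output); span
  # prediction only needs the per-token sequence output.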
  core_model = hub.KerasLayer(hub_module_url, trainable=True)
  _, sequence_output = core_model(
      [input_word_ids, input_mask, input_type_ids])

  squad_logits_layer = BertSquadLogitsLayer(
      initializer=initializer, name='squad_logits')
  start_logits, end_logits = squad_logits_layer(sequence_output)

  squad = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
      },
      outputs=[start_logits, end_logits],
      name='squad_model')
  return squad, core_model
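

# Illustrative usage (a sketch; the config path and sequence length are
# assumptions): building the SQuAD model without a hub module.
#
#   config = bert_modeling.BertConfig.from_json_file('bert_config.json')
#   squad, encoder = squad_model(config, max_seq_length=384)
#   # `squad` outputs [start_logits, end_logits] for each input example.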


def classifier_model(bert_config,
                     num_labels,
                     max_seq_length,
                     final_layer_initializer=None,
                     hub_module_url=None):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig, the config defines the core
      BERT or ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults
      to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/URL to the BERT module.

  Returns:
    Combined prediction model (words, mask, type) -> (one-hot labels), and
    the BERT sub-model (words, mask, type) -> (bert_outputs).
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return bert_classifier.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
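  # The TF-Hub BERT layer returns (pooled_output, sequence_output);
  # classification uses the pooled [CLS] representation.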
  bert_model = hub.KerasLayer(hub_module_url, trainable=True)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)

  output = tf.keras.layers.Dense(
      num_labels,
      kernel_initializer=initializer,
      name='output')(
          output)
  return tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=output), bert_model
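

# Illustrative usage (a sketch; the config path, label count, and sequence
# length are assumptions):
#
#   config = bert_modeling.BertConfig.from_json_file('bert_config.json')
#   model, encoder = classifier_model(
#       config, num_labels=3, max_seq_length=128)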