# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gin
import tensorflow as tf
import tensorflow_hub as hub

from official.modeling import tf_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs
from official.nlp.modeling import losses
from official.nlp.modeling import models
from official.nlp.modeling import networks


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining."""

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics."""
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
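    # Reduce to a weighted mean: the label weights zero out padded prediction
    # slots, and the small epsilon keeps the division stable if all weights
    # happen to be zero.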
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    if sentence_labels is not None:
      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
          sentence_labels, sentence_output)
      self.add_metric(
          next_sentence_accuracy,
          name='next_sentence_accuracy',
          aggregation='mean')

    if next_sentence_loss is not None:
      self.add_metric(
          next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self,
           lm_output,
           sentence_output,
           lm_label_ids,
           lm_label_weights,
           sentence_labels=None):
    """Implements call() for the layer."""
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output = tf.cast(lm_output, tf.float32)

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)

    if sentence_labels is not None:
      sentence_output = tf.cast(sentence_output, tf.float32)
      sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
          labels=sentence_labels, predictions=sentence_output)
      loss = mask_label_loss + sentence_loss
    else:
      sentence_loss = None
      loss = mask_label_loss

    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
    # TODO(hongkuny): Avoid this hack and switch to add_loss.
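    # tf.fill broadcasts the scalar loss to a batch-sized vector so that
    # Keras' standard per-example loss aggregation can consume it.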
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output, sentence_labels,
                      sentence_loss)
    return final_loss
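
# Illustrative sketch of how this layer is wired up (see pretrain_model below
# for the actual wiring; the vocab size here is an assumption):
#
#   loss_layer = BertPretrainLossAndMetricLayer(vocab_size=30522)
#   per_example_loss = loss_layer(lm_output, sentence_output, masked_lm_ids,
#                                 masked_lm_weights, next_sentence_labels)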


@gin.configurable
def get_transformer_encoder(bert_config,
                            sequence_length,
                            transformer_encoder_cls=None):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: Maximum sequence length of the training data.
    transformer_encoder_cls: An EncoderScaffold class. If None, uses the
      default BERT encoder implementation.

  Returns:
    A networks.TransformerEncoder object.
  """
  if transformer_encoder_cls is not None:
    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
    embedding_cfg = dict(
        vocab_size=bert_config.vocab_size,
        type_vocab_size=bert_config.type_vocab_size,
        hidden_size=bert_config.hidden_size,
        seq_length=sequence_length,
        max_seq_length=bert_config.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
        dropout_rate=bert_config.hidden_dropout_prob,
    )
    hidden_cfg = dict(
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=bert_config.num_hidden_layers,
        pooled_output_dim=bert_config.hidden_size,
    )

    # Relies on gin configuration to define the Transformer encoder arguments.
    return transformer_encoder_cls(**kwargs)

  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=sequence_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  if isinstance(bert_config, albert_configs.AlbertConfig):
    kwargs['embedding_width'] = bert_config.embedding_size
    return networks.AlbertTransformerEncoder(**kwargs)
  else:
    assert isinstance(bert_config, configs.BertConfig)
    return networks.TransformerEncoder(**kwargs)
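
# A minimal usage sketch (illustrative; the config values are assumptions,
# not defaults taken from this file):
#
#   bert_config = configs.BertConfig(
#       vocab_size=30522, hidden_size=768, num_hidden_layers=12,
#       num_attention_heads=12, intermediate_size=3072)
#   encoder = get_transformer_encoder(bert_config, sequence_length=128)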


def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None,
                   use_next_sentence_label=True):
  """Returns model to be used for pre-training.

  Args:
      bert_config: Configuration that defines the core BERT model.
      seq_length: Maximum sequence length of the training data.
      max_predictions_per_seq: Maximum number of tokens in sequence to mask out
        and use for pretraining.
      initializer: Initializer for weights in BertPretrainer.
      use_next_sentence_label: Whether to use the next sentence label.

  Returns:
      Pretraining model as well as core BERT submodel from which to save
      weights after pretraining.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  masked_lm_weights = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)

  if use_next_sentence_label:
    next_sentence_labels = tf.keras.layers.Input(
        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
  else:
    next_sentence_labels = None

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = models.BertPretrainer(
      network=transformer_encoder,
      num_classes=2,  # The next sentence prediction label has two classes.
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='predictions')

  lm_output, sentence_output = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])

  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)
  inputs = {
      'input_word_ids': input_word_ids,
      'input_mask': input_mask,
      'input_type_ids': input_type_ids,
      'masked_lm_positions': masked_lm_positions,
      'masked_lm_ids': masked_lm_ids,
      'masked_lm_weights': masked_lm_weights,
  }
  if use_next_sentence_label:
    inputs['next_sentence_labels'] = next_sentence_labels

  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
  return keras_model, transformer_encoder
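
# Illustrative call (argument values are assumptions): the first return value
# is the Keras model whose output is the per-example pretraining loss; the
# second is the core encoder from which to save weights after pretraining.
#
#   pretrain_keras_model, encoder = pretrain_model(
#       bert_config, seq_length=128, max_predictions_per_seq=20)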


def squad_model(bert_config,
                max_seq_length,
                initializer=None,
                hub_module_url=None,
                hub_module_trainable=True):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config defines the core Bert model.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaults to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/URL to a BERT module.
    hub_module_trainable: True to fine-tune layers in the hub module.

  Returns:
    A tuple of (1) keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return models.BertSpanLabeler(
        network=bert_encoder, initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
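  # The TF-Hub BERT module is expected to return (pooled_output,
  # sequence_output) for these three inputs, matching the unpacking below.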
  core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, sequence_output = core_model(
      [input_word_ids, input_mask, input_type_ids])
  bert_encoder = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
      },
      outputs=[sequence_output, pooled_output],
      name='core_model')
  return models.BertSpanLabeler(
      network=bert_encoder, initializer=initializer), bert_encoder
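
# Illustrative usage (max_seq_length value is an assumption): with no
# hub_module_url a fresh encoder is built from bert_config; with a hub URL
# the encoder weights come from the TF-Hub module instead.
#
#   span_labeler, encoder = squad_model(bert_config, max_seq_length=384)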


def classifier_model(bert_config,
                     num_labels,
                     max_seq_length,
                     final_layer_initializer=None,
                     hub_module_url=None,
                     hub_module_trainable=True):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
      ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults
      to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/URL to a BERT module.
    hub_module_trainable: True to fine-tune layers in the hub module.

  Returns:
    Combined prediction model (words, mask, type) -> (one-hot labels) and
    BERT sub-model (words, mask, type) -> (bert_outputs).
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return models.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)

  output = tf.keras.layers.Dense(
      num_labels, kernel_initializer=initializer, name='output')(
          output)
  return tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=output), bert_model
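

# Illustrative usage sketch (num_labels and max_seq_length are assumptions):
#
#   classifier, encoder = classifier_model(
#       bert_config, num_labels=2, max_seq_length=128)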