bert_models.py 14.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""

Hongkun Yu's avatar
Hongkun Yu committed
17
import gin
18
import tensorflow as tf
19
import tensorflow_hub as hub
20

21
from official.modeling import tf_utils
22
from official.nlp.albert import configs as albert_configs
23
from official.nlp.bert import configs
24
from official.nlp.modeling import models
Hongkun Yu's avatar
Hongkun Yu committed
25
from official.nlp.modeling import networks
26
27
28
29
30


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining.

  The layer takes masked-LM logits and next-sentence logits together with
  their labels. It registers accuracy/loss metrics through `add_metric` and
  returns the combined pretraining loss repeated once per example in the batch.
  """

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def get_config(self):
    """Returns the layer config, including `vocab_size`, for serialization."""
    config = super(BertPretrainLossAndMetricLayer, self).get_config()
    config.update(self.config)
    return config

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics.

    Args:
      lm_output: Masked-LM logits; accuracy is computed from their argmax.
      lm_labels: Target token ids for the masked positions.
      lm_label_weights: Float weights per masked position; used to compute a
        weighted-average masked-LM accuracy.
      lm_example_loss: Masked-LM loss value to report.
      sentence_output: Next-sentence logits.
      sentence_labels: Next-sentence labels, or None to skip that metric.
      next_sentence_loss: Next-sentence loss, or None to skip that metric.
    """
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    # Small epsilon keeps the metric finite when every weight is zero.
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    if sentence_labels is not None:
      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
          sentence_labels, sentence_output)
      self.add_metric(
          next_sentence_accuracy,
          name='next_sentence_accuracy',
          aggregation='mean')

    if next_sentence_loss is not None:
      self.add_metric(
          next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self,
           lm_output_logits,
           sentence_output_logits,
           lm_label_ids,
           lm_label_weights,
           sentence_labels=None):
    """Implements call() for the layer.

    Args:
      lm_output_logits: Masked-LM logits; cast to float32 before the loss.
      sentence_output_logits: Next-sentence logits; cast to float32 and used
        only when `sentence_labels` is given.
      lm_label_ids: Target token ids for the masked positions. Its first
        dimension is taken as the batch size of the returned tensor.
      lm_label_weights: Weights per masked position (cast to float32); the
        masked-LM loss is the weighted mean over positions.
      sentence_labels: Optional next-sentence labels. When None, the loss is
        the masked-LM loss alone.

    Returns:
      A 1-D tensor of length batch_size where every entry equals the total
      scalar pretraining loss.
    """
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output_logits = tf.cast(lm_output_logits, tf.float32)

    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
        lm_label_ids, lm_output_logits, from_logits=True)
    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
    # divide_no_nan returns 0 instead of NaN when all weights are zero.
    mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
                                            lm_denominator_loss)

    if sentence_labels is not None:
      sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
          sentence_labels, sentence_output_logits, from_logits=True)
      sentence_loss = tf.reduce_mean(sentence_loss)
      loss = mask_label_loss + sentence_loss
    else:
      sentence_loss = None
      loss = mask_label_loss

    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
    # TODO(hongkuny): Avoids the hack and switches add_loss.
    # The scalar loss is tiled to one value per example so the output can be
    # used as a per-example model output.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output_logits, sentence_labels,
                      sentence_loss)
    return final_loss


Hongkun Yu's avatar
Hongkun Yu committed
101
102
@gin.configurable
def get_transformer_encoder(bert_config,
                            sequence_length=None,
                            transformer_encoder_cls=None,
                            output_range=None):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A `configs.BertConfig` or `albert_configs.AlbertConfig`
      object.
    sequence_length: [Deprecated]. Accepted for backward compatibility and
      ignored.
    transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
      default BERT encoder implementation.
    output_range: the sequence output range, [0, output_range). Default setting
      is to return the entire sequence output. Only forwarded to
      `networks.BertEncoder`; it is not passed to `transformer_encoder_cls` or
      to `networks.AlbertEncoder`.

  Returns:
    An encoder object: `transformer_encoder_cls(...)` when a class is given,
    otherwise `networks.AlbertEncoder` for an `AlbertConfig` or
    `networks.BertEncoder` for a `BertConfig`.
  """
  del sequence_length
  if transformer_encoder_cls is not None:
    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
    embedding_cfg = dict(
        vocab_size=bert_config.vocab_size,
        type_vocab_size=bert_config.type_vocab_size,
        hidden_size=bert_config.hidden_size,
        max_seq_length=bert_config.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
        dropout_rate=bert_config.hidden_dropout_prob,
    )
    hidden_cfg = dict(
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=bert_config.num_hidden_layers,
        pooled_output_dim=bert_config.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range))

    # Relies on gin configuration to define the Transformer encoder arguments.
    return transformer_encoder_cls(**kwargs)

  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      embedding_width=bert_config.embedding_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  # AlbertConfig is checked first; the ALBERT encoder takes no output_range.
  if isinstance(bert_config, albert_configs.AlbertConfig):
    return networks.AlbertEncoder(**kwargs)
  else:
    assert isinstance(bert_config, configs.BertConfig)
    kwargs['output_range'] = output_range
    return networks.BertEncoder(**kwargs)
171
172


173
174
175
def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None,
                   use_next_sentence_label=True,
                   return_core_pretrainer_model=False):
  """Returns model to be used for pre-training.

  Args:
    bert_config: Configuration that defines the core BERT model.
    seq_length: Maximum sequence length of the training data.
    max_predictions_per_seq: Maximum number of tokens in sequence to mask out
      and use for pretraining.
    initializer: Initializer for weights in BertPretrainer. Defaults to a
      TruncatedNormal initializer with `bert_config.initializer_range`.
    use_next_sentence_label: Whether to use the next sentence label. When
      False, no `next_sentence_labels` input is created and the loss is the
      masked-LM loss only.
    return_core_pretrainer_model: Whether to also return the `BertPretrainer`
      object.

  Returns:
    A Tuple of (1) Pretraining model, (2) core BERT submodel from which to
    save weights after pretraining, and (3) optional core `BertPretrainer`
    object if argument `return_core_pretrainer_model` is True.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  # Declared as int32; the loss layer casts the weights to float32.
  masked_lm_weights = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)

  if use_next_sentence_label:
    next_sentence_labels = tf.keras.layers.Input(
        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
  else:
    next_sentence_labels = None

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = models.BertPretrainer(
      network=transformer_encoder,
      embedding_table=transformer_encoder.get_embedding_table(),
      num_classes=2,  # The next sentence prediction label has two classes.
      activation=tf_utils.get_activation(bert_config.hidden_act),
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='logits')

  outputs = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
  lm_output = outputs['masked_lm']
  sentence_output = outputs['classification']
  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)
  inputs = {
      'input_word_ids': input_word_ids,
      'input_mask': input_mask,
      'input_type_ids': input_type_ids,
      'masked_lm_positions': masked_lm_positions,
      'masked_lm_ids': masked_lm_ids,
      'masked_lm_weights': masked_lm_weights,
  }
  # Only expose the NSP label input when it actually exists.
  if use_next_sentence_label:
    inputs['next_sentence_labels'] = next_sentence_labels

  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
  if return_core_pretrainer_model:
    return keras_model, transformer_encoder, pretrainer_model
  else:
    return keras_model, transformer_encoder
256
257


Hongkun Yu's avatar
Hongkun Yu committed
258
259
260
def squad_model(bert_config,
                max_seq_length,
                initializer=None,
                hub_module_url=None,
                hub_module_trainable=True):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config defines the core Bert model.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaulted to TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module. When empty, the encoder is
      built locally from `bert_config`.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    A tuple of (1) keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return models.BertSpanLabeler(
        network=bert_encoder, initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, sequence_output = core_model(
      [input_word_ids, input_mask, input_type_ids])
  # Wrap the hub layer in a Model whose outputs are ordered
  # [sequence_output, pooled_output], the reverse of the hub layer's order.
  bert_encoder = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
      },
      outputs=[sequence_output, pooled_output],
      name='core_model')
  return models.BertSpanLabeler(
      network=bert_encoder, initializer=initializer), bert_encoder
304
305
306
307


def classifier_model(bert_config,
                     num_labels,
                     max_seq_length=None,
                     final_layer_initializer=None,
                     hub_module_url=None,
                     hub_module_trainable=True):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
      ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for final dense layer. Defaults to a
      TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module. When empty, the encoder is
      built locally from `bert_config`.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    A tuple of (1) the combined prediction model (words, mask, type) ->
    `num_labels` outputs per example and (2) the BERT sub-model
    (words, mask, type) -> (bert_outputs). In the TF-Hub path, (1) returns the
    un-activated outputs of the final Dense layer and (2) is the
    `hub.KerasLayer` itself.
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    # Only the first sequence position is needed for classification.
    bert_encoder = get_transformer_encoder(
        bert_config, max_seq_length, output_range=1)
    return models.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)

  output = tf.keras.layers.Dense(
      num_labels, kernel_initializer=initializer, name='output')(
          output)
  return tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=output), bert_model