# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gin
import tensorflow as tf
import tensorflow_hub as hub

from official.modeling import tf_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs
from official.nlp.modeling import models
from official.nlp.modeling import networks


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining."""

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics."""
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
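    # The weights zero out padded prediction slots, so this is a weighted
    # mean over real masked positions only; the 1e-5 term guards against a
    # zero denominator. E.g., per-position accuracies [1., 0., 1.] with
    # weights [1., 1., 0.] yield 1/2, not 2/3.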
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    if sentence_labels is not None:
      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
          sentence_labels, sentence_output)
      self.add_metric(
          next_sentence_accuracy,
          name='next_sentence_accuracy',
          aggregation='mean')

    if next_sentence_loss is not None:
      self.add_metric(
          next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self,
           lm_output_logits,
           sentence_output_logits,
           lm_label_ids,
           lm_label_weights,
           sentence_labels=None):
    """Implements call() for the layer."""
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output_logits = tf.cast(lm_output_logits, tf.float32)

    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
        lm_label_ids, lm_output_logits, from_logits=True)
    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
    mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
                                            lm_denominator_loss)
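    # tf.math.divide_no_nan returns 0 when the summed label weights are
    # zero, so a batch with no real masked positions contributes no NaNs.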

    if sentence_labels is not None:
      sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
          sentence_labels, sentence_output_logits, from_logits=True)
      sentence_loss = tf.reduce_mean(sentence_loss)
      loss = mask_label_loss + sentence_loss
    else:
      sentence_loss = None
      loss = mask_label_loss

    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
    # TODO(hongkuny): Avoid this hack and switch to add_loss.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output_logits, sentence_labels,
                      sentence_loss)
    return final_loss
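
# Illustrative usage sketch (not part of this module's API; tensor names and
# shapes below are assumptions for a toy setup):
#
#   loss_layer = BertPretrainLossAndMetricLayer(vocab_size=30522)
#   per_example_loss = loss_layer(
#       lm_output_logits,        # [batch, max_predictions, vocab_size]
#       sentence_output_logits,  # [batch, 2]
#       lm_label_ids,            # [batch, max_predictions]
#       lm_label_weights,        # [batch, max_predictions]
#       sentence_labels)         # [batch, 1]; pass None to skip the NSP loss.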


@gin.configurable
def get_transformer_encoder(bert_config,
                            sequence_length=None,
                            transformer_encoder_cls=None,
                            output_range=None):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: [Deprecated].
    transformer_encoder_cls: An EncoderScaffold class. If it is None, uses the
      default BERT encoder implementation.
    output_range: the sequence output range, [0, output_range). Default setting
      is to return the entire sequence output.

  Returns:
    A networks.TransformerEncoder object.
  """
  del sequence_length
  if transformer_encoder_cls is not None:
    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
    embedding_cfg = dict(
        vocab_size=bert_config.vocab_size,
        type_vocab_size=bert_config.type_vocab_size,
        hidden_size=bert_config.hidden_size,
        max_seq_length=bert_config.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
        dropout_rate=bert_config.hidden_dropout_prob,
    )
    hidden_cfg = dict(
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=bert_config.num_hidden_layers,
        pooled_output_dim=bert_config.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range))

    # Relies on gin configuration to define the Transformer encoder arguments.
    return transformer_encoder_cls(**kwargs)

  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      embedding_width=bert_config.embedding_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  if isinstance(bert_config, albert_configs.AlbertConfig):
    return networks.AlbertTransformerEncoder(**kwargs)
  else:
    assert isinstance(bert_config, configs.BertConfig)
    kwargs['output_range'] = output_range
    return networks.TransformerEncoder(**kwargs)
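
# Minimal usage sketch for the encoder factory (assumption: default BERT-Base
# style hyperparameters; `encoder` is an illustrative local name):
#
#   bert_config = configs.BertConfig(vocab_size=30522)
#   encoder = get_transformer_encoder(bert_config)
#   # `encoder` consumes [word_ids, mask, type_ids] and produces the
#   # per-token sequence output plus the pooled [CLS] output.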


def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None,
                   use_next_sentence_label=True,
                   return_core_pretrainer_model=False):
  """Returns model to be used for pre-training.

  Args:
      bert_config: Configuration that defines the core BERT model.
      seq_length: Maximum sequence length of the training data.
      max_predictions_per_seq: Maximum number of tokens in sequence to mask out
        and use for pretraining.
      initializer: Initializer for weights in BertPretrainer.
      use_next_sentence_label: Whether to use the next sentence label.
      return_core_pretrainer_model: Whether to also return the `BertPretrainer`
        object.

  Returns:
      A tuple of (1) pretraining model, (2) core BERT submodel from which to
      save weights after pretraining, and (3) optional core `BertPretrainer`
      object if argument `return_core_pretrainer_model` is True.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  masked_lm_weights = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)

  if use_next_sentence_label:
    next_sentence_labels = tf.keras.layers.Input(
        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
  else:
    next_sentence_labels = None

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = models.BertPretrainer(
      network=transformer_encoder,
      embedding_table=transformer_encoder.get_embedding_table(),
      num_classes=2,  # The next sentence prediction label has two classes.
      activation=tf_utils.get_activation(bert_config.hidden_act),
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='logits')

  outputs = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
  lm_output = outputs['masked_lm']
  sentence_output = outputs['classification']
  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)
  inputs = {
      'input_word_ids': input_word_ids,
      'input_mask': input_mask,
      'input_type_ids': input_type_ids,
      'masked_lm_positions': masked_lm_positions,
      'masked_lm_ids': masked_lm_ids,
      'masked_lm_weights': masked_lm_weights,
  }
  if use_next_sentence_label:
    inputs['next_sentence_labels'] = next_sentence_labels

  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
  if return_core_pretrainer_model:
    return keras_model, transformer_encoder, pretrainer_model
  else:
    return keras_model, transformer_encoder
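
# Example wiring (illustrative only; seq_length and max_predictions_per_seq
# are arbitrary toy values):
#
#   bert_config = configs.BertConfig(vocab_size=30522)
#   pretrain_net, core_encoder = pretrain_model(
#       bert_config, seq_length=128, max_predictions_per_seq=20)
#   # `pretrain_net` maps the masked-LM/NSP features to the combined loss;
#   # `core_encoder` holds the weights to export after pretraining.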


def squad_model(bert_config,
                max_seq_length,
                initializer=None,
                hub_module_url=None,
                hub_module_trainable=True):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config defines the core Bert model.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaults to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    A tuple of (1) keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return models.BertSpanLabeler(
        network=bert_encoder, initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, sequence_output = core_model(
      [input_word_ids, input_mask, input_type_ids])
  bert_encoder = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
      },
      outputs=[sequence_output, pooled_output],
      name='core_model')
  return models.BertSpanLabeler(
      network=bert_encoder, initializer=initializer), bert_encoder
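
# Example (illustrative; max_seq_length=384 is the common SQuAD setting but
# still an assumption here):
#
#   span_labeler, core = squad_model(bert_config, max_seq_length=384)
#   # `span_labeler` outputs start/end logits over the input sequence;
#   # restore pretrained checkpoint weights into `core` before fine-tuning.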


def classifier_model(bert_config,
                     num_labels,
                     max_seq_length=None,
                     final_layer_initializer=None,
                     hub_module_url=None,
                     hub_module_trainable=True):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
      ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults to
      a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    Combined prediction model (words, mask, type) -> (one-hot labels)
    BERT sub-model (words, mask, type) -> (bert_outputs)
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    bert_encoder = get_transformer_encoder(
        bert_config, max_seq_length, output_range=1)
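    # Note: output_range=1 above asks the encoder to compute only the first
    # (CLS) position of the final sequence output, which is all the
    # classification head needs.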
    return models.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)

  output = tf.keras.layers.Dense(
      num_labels, kernel_initializer=initializer, name='output')(
          output)
  return tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=output), bert_model
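
# Example (illustrative; `num_labels=3` is an arbitrary assumption):
#
#   classifier, core = classifier_model(
#       bert_config, num_labels=3, max_seq_length=128)
#   # `classifier` maps (input_word_ids, input_mask, input_type_ids) to
#   # per-class outputs; `core` exposes the underlying encoder weights.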