# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gin
import tensorflow as tf
import tensorflow_hub as hub

from official.modeling import tf_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs
from official.nlp.modeling import losses
from official.nlp.modeling import models
from official.nlp.modeling import networks


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining."""

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics."""
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    if sentence_labels is not None:
      next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
          sentence_labels, sentence_output)
      self.add_metric(
          next_sentence_accuracy,
          name='next_sentence_accuracy',
          aggregation='mean')

    if next_sentence_loss is not None:
      self.add_metric(
          next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self,
           lm_output,
           sentence_output,
           lm_label_ids,
           lm_label_weights,
           sentence_labels=None):
    """Implements call() for the layer."""
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output = tf.cast(lm_output, tf.float32)

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)

    if sentence_labels is not None:
      sentence_output = tf.cast(sentence_output, tf.float32)
      sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
          labels=sentence_labels, predictions=sentence_output)
      loss = mask_label_loss + sentence_loss
    else:
      sentence_loss = None
      loss = mask_label_loss

    # The loss is a scalar; tile it to a [batch]-shaped vector so that Keras
    # treats it as a per-example model output.
    batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
    # TODO(hongkuny): Avoid this hack and switch to add_loss.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output, sentence_labels,
                      sentence_loss)
    return final_loss
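

# A minimal usage sketch, illustrative only and with made-up shapes; it is
# not part of the original API. The layer is normally wired up inside
# pretrain_model() below. `lm_output` and `sentence_output` are assumed to
# hold log-probabilities, matching the 'predictions' output of
# models.BertPretrainer.
def _example_pretrain_loss_layer():
  """Hedged usage sketch for BertPretrainLossAndMetricLayer."""
  batch_size, num_predictions, vocab_size = 2, 3, 11
  layer = BertPretrainLossAndMetricLayer(vocab_size=vocab_size)
  lm_output = tf.nn.log_softmax(
      tf.random.normal([batch_size, num_predictions, vocab_size]), axis=-1)
  sentence_output = tf.nn.log_softmax(
      tf.random.normal([batch_size, 2]), axis=-1)
  lm_label_ids = tf.zeros([batch_size, num_predictions], dtype=tf.int32)
  lm_label_weights = tf.ones([batch_size, num_predictions], dtype=tf.int32)
  sentence_labels = tf.zeros([batch_size, 1], dtype=tf.int32)
  # Returns the total loss tiled to a per-example [batch_size] vector.
  return layer(lm_output, sentence_output, lm_label_ids, lm_label_weights,
               sentence_labels)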


@gin.configurable
def get_transformer_encoder(bert_config,
                            sequence_length,
                            transformer_encoder_cls=None,
                            output_range=None):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: Maximum sequence length of the training data.
    transformer_encoder_cls: An EncoderScaffold class. If it is None, uses the
      default BERT encoder implementation.
    output_range: The sequence output range, [0, output_range). The default is
      to return the entire sequence output.

  Returns:
    A networks.TransformerEncoder object.
  """
  if transformer_encoder_cls is not None:
    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
    embedding_cfg = dict(
        vocab_size=bert_config.vocab_size,
        type_vocab_size=bert_config.type_vocab_size,
        hidden_size=bert_config.hidden_size,
        seq_length=sequence_length,
        max_seq_length=bert_config.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
        dropout_rate=bert_config.hidden_dropout_prob,
    )
    hidden_cfg = dict(
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=bert_config.num_hidden_layers,
        pooled_output_dim=bert_config.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range))

    # Relies on gin configuration to define the Transformer encoder arguments.
    return transformer_encoder_cls(**kwargs)

  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=sequence_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      embedding_width=bert_config.embedding_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  if isinstance(bert_config, albert_configs.AlbertConfig):
    return networks.AlbertTransformerEncoder(**kwargs)
  else:
    assert isinstance(bert_config, configs.BertConfig)
    kwargs['output_range'] = output_range
    return networks.TransformerEncoder(**kwargs)
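

# A usage sketch with deliberately tiny, made-up hyperparameters (not from
# the original file): builds a small encoder directly from a BertConfig. The
# returned network consumes [input_word_ids, input_mask, input_type_ids].
def _example_get_transformer_encoder():
  """Hedged usage sketch for get_transformer_encoder()."""
  small_config = configs.BertConfig(
      vocab_size=100,
      hidden_size=16,
      num_hidden_layers=2,
      num_attention_heads=2,
      intermediate_size=32)
  return get_transformer_encoder(small_config, sequence_length=8)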


def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None,
                   use_next_sentence_label=True):
  """Returns model to be used for pre-training.

  Args:
      bert_config: Configuration that defines the core BERT model.
      seq_length: Maximum sequence length of the training data.
      max_predictions_per_seq: Maximum number of tokens in sequence to mask out
        and use for pretraining.
      initializer: Initializer for weights in BertPretrainer.
      use_next_sentence_label: Whether to use the next sentence label.

  Returns:
      Pretraining model as well as core BERT submodel from which to save
      weights after pretraining.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  masked_lm_weights = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)

  if use_next_sentence_label:
    next_sentence_labels = tf.keras.layers.Input(
        shape=(1,), name='next_sentence_labels', dtype=tf.int32)
  else:
    next_sentence_labels = None

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = models.BertPretrainer(
      network=transformer_encoder,
      embedding_table=transformer_encoder.get_embedding_table(),
      num_classes=2,  # The next sentence prediction label has two classes.
      activation=tf_utils.get_activation(bert_config.hidden_act),
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='predictions')

  lm_output, sentence_output = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])

  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)
  inputs = {
      'input_word_ids': input_word_ids,
      'input_mask': input_mask,
      'input_type_ids': input_type_ids,
      'masked_lm_positions': masked_lm_positions,
      'masked_lm_ids': masked_lm_ids,
      'masked_lm_weights': masked_lm_weights,
  }
  if use_next_sentence_label:
    inputs['next_sentence_labels'] = next_sentence_labels

  keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
  return keras_model, transformer_encoder
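

# A usage sketch with made-up sizes: pretrain_model() returns both the Keras
# model whose output is the pretraining loss and the underlying encoder,
# whose weights are what gets exported after pretraining.
def _example_pretrain_model():
  """Hedged usage sketch for pretrain_model()."""
  small_config = configs.BertConfig(
      vocab_size=100, hidden_size=16, num_hidden_layers=2,
      num_attention_heads=2, intermediate_size=32)
  pretraining_model, encoder = pretrain_model(
      small_config, seq_length=8, max_predictions_per_seq=2)
  return pretraining_model, encoder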


def squad_model(bert_config,
                max_seq_length,
                initializer=None,
                hub_module_url=None,
                hub_module_trainable=True):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config that defines the core BERT model.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaults to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    A tuple of (1) a Keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return models.BertSpanLabeler(
        network=bert_encoder, initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, sequence_output = core_model(
      [input_word_ids, input_mask, input_type_ids])
  bert_encoder = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
      },
      outputs=[sequence_output, pooled_output],
      name='core_model')
  return models.BertSpanLabeler(
      network=bert_encoder, initializer=initializer), bert_encoder
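

# A usage sketch with made-up sizes: without hub_module_url, squad_model()
# builds the encoder from the config and wraps it in models.BertSpanLabeler,
# which emits start and end logits over the input sequence.
def _example_squad_model():
  """Hedged usage sketch for squad_model()."""
  small_config = configs.BertConfig(
      vocab_size=100, hidden_size=16, num_hidden_layers=2,
      num_attention_heads=2, intermediate_size=32)
  span_labeler, encoder = squad_model(small_config, max_seq_length=8)
  return span_labeler, encoder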


def classifier_model(bert_config,
                     num_labels,
                     max_seq_length=None,
                     final_layer_initializer=None,
                     hub_module_url=None,
                     hub_module_trainable=True):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig, the config that defines the core
      BERT or ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults
      to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to Bert module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    Combined prediction model (words, mask, type) -> (one-hot labels)
    BERT sub-model (words, mask, type) -> (bert_outputs)
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    bert_encoder = get_transformer_encoder(
        bert_config, max_seq_length, output_range=1)
    return models.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)

  output = tf.keras.layers.Dense(
      num_labels, kernel_initializer=initializer, name='output')(
          output)
  return tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=output), bert_model
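

# A usage sketch with made-up sizes: a two-class classifier built directly
# from a config. classifier_model() passes output_range=1 to the encoder, so
# only the first-token sequence output is computed for classification.
def _example_classifier_model():
  """Hedged usage sketch for classifier_model()."""
  small_config = configs.BertConfig(
      vocab_size=100, hidden_size=16, num_hidden_layers=2,
      num_attention_heads=2, intermediate_size=32)
  classifier, encoder = classifier_model(
      small_config, num_labels=2, max_seq_length=8)
  return classifier, encoder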