# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gin
import tensorflow as tf
import tensorflow_hub as hub

from official.modeling import tf_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs
from official.nlp.modeling import losses
from official.nlp.modeling import models
from official.nlp.modeling import networks


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining."""

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics."""
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
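    # Average the per-position accuracy over masked positions only, weighting
    # by `lm_label_weights`; the small epsilon avoids division by zero when a
    # batch contains no masked tokens.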
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        sentence_labels, sentence_output)
    self.add_metric(
        next_sentence_accuracy,
        name='next_sentence_accuracy',
        aggregation='mean')

    self.add_metric(
        next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
           sentence_labels):
    """Implements call() for the layer."""
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output = tf.cast(lm_output, tf.float32)
    sentence_output = tf.cast(sentence_output, tf.float32)

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=sentence_labels, predictions=sentence_output)
    loss = mask_label_loss + sentence_loss
    batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
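    # Broadcast the scalar loss to a per-example vector so Keras can treat it
    # as a model output and average it over the batch.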
    # TODO(hongkuny): Avoid this hack and switch to add_loss.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output, sentence_labels,
                      sentence_loss)
    return final_loss


@gin.configurable
def get_transformer_encoder(bert_config,
                            sequence_length,
                            transformer_encoder_cls=None):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: Maximum sequence length of the training data.
    transformer_encoder_cls: An EncoderScaffold class. If None, uses the
      default BERT encoder implementation.

  Returns:
    A networks.TransformerEncoder object.
  """
  if transformer_encoder_cls is not None:
    # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
    embedding_cfg = dict(
        vocab_size=bert_config.vocab_size,
        type_vocab_size=bert_config.type_vocab_size,
        hidden_size=bert_config.hidden_size,
        seq_length=sequence_length,
        max_seq_length=bert_config.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=bert_config.initializer_range),
        dropout_rate=bert_config.hidden_dropout_prob,
    )
    hidden_cfg = dict(
        num_attention_heads=bert_config.num_attention_heads,
        intermediate_size=bert_config.intermediate_size,
        intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
        dropout_rate=bert_config.hidden_dropout_prob,
        attention_dropout_rate=bert_config.attention_probs_dropout_prob,
    )
    kwargs = dict(embedding_cfg=embedding_cfg, hidden_cfg=hidden_cfg,
                  num_hidden_instances=bert_config.num_hidden_layers,
                  num_output_classes=bert_config.hidden_size)

    # Relies on gin configuration to define the Transformer encoder arguments.
    return transformer_encoder_cls(**kwargs)

  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=sequence_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  if isinstance(bert_config, albert_configs.AlbertConfig):
    kwargs['embedding_width'] = bert_config.embedding_size
    return networks.AlbertTransformerEncoder(**kwargs)
  else:
    assert isinstance(bert_config, configs.BertConfig)
    return networks.TransformerEncoder(**kwargs)
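

# Example usage of `get_transformer_encoder` (a minimal sketch; the config
# values are illustrative, not recommendations):
#
#   bert_config = configs.BertConfig(vocab_size=30522)
#   encoder = get_transformer_encoder(bert_config, sequence_length=128)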


def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None):
  """Returns model to be used for pre-training.

  Args:
    bert_config: Configuration that defines the core BERT model.
    seq_length: Maximum sequence length of the training data.
    max_predictions_per_seq: Maximum number of tokens in sequence to mask out
      and use for pretraining.
    initializer: Initializer for weights in BertPretrainer.

  Returns:
    Pretraining model as well as core BERT submodel from which to save
    weights after pretraining.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  masked_lm_weights = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)
  next_sentence_labels = tf.keras.layers.Input(
      shape=(1,), name='next_sentence_labels', dtype=tf.int32)

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = models.BertPretrainer(
      network=transformer_encoder,
      num_classes=2,  # The next sentence prediction label has two classes.
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='predictions')

  lm_output, sentence_output = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])

  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)
  keras_model = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
          'masked_lm_positions': masked_lm_positions,
          'masked_lm_ids': masked_lm_ids,
          'masked_lm_weights': masked_lm_weights,
          'next_sentence_labels': next_sentence_labels,
      },
      outputs=output_loss)
  return keras_model, transformer_encoder
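

# Example usage of `pretrain_model` (a minimal sketch; the values are
# illustrative):
#
#   pretraining_model, encoder = pretrain_model(
#       configs.BertConfig(vocab_size=30522),
#       seq_length=128,
#       max_predictions_per_seq=20)
#
#   `pretraining_model` outputs the per-example pretraining loss, and
#   `encoder` is the core submodel whose weights are saved after pretraining.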


def squad_model(bert_config,
                max_seq_length,
                initializer=None,
                hub_module_url=None,
                hub_module_trainable=True):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config defines the core Bert model.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaults to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to a BERT module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    A tuple of (1) keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return models.BertSpanLabeler(
        network=bert_encoder, initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, sequence_output = core_model(
      [input_word_ids, input_mask, input_type_ids])
  bert_encoder = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
      },
      outputs=[sequence_output, pooled_output],
      name='core_model')
  return models.BertSpanLabeler(
      network=bert_encoder, initializer=initializer), bert_encoder
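

# Example usage of `squad_model` (a minimal sketch; the values are
# illustrative):
#
#   span_labeler, encoder = squad_model(
#       configs.BertConfig(vocab_size=30522), max_seq_length=384)
#
#   The same call with `hub_module_url` set loads the encoder from a TF-Hub
#   module instead of building it from the config.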


def classifier_model(bert_config,
                     num_labels,
                     max_seq_length,
                     final_layer_initializer=None,
                     hub_module_url=None,
                     hub_module_trainable=True):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
      ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults
      to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to a BERT module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    Combined prediction model (words, mask, type) -> (one-hot labels)
    BERT sub-model (words, mask, type) -> (bert_outputs)
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return models.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  bert_model = hub.KerasLayer(
      hub_module_url, trainable=hub_module_trainable)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)

  output = tf.keras.layers.Dense(
      num_labels, kernel_initializer=initializer, name='output')(
          output)
  return tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=output), bert_model
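

# Example usage of `classifier_model` (a minimal sketch; the values are
# illustrative):
#
#   classifier, encoder = classifier_model(
#       configs.BertConfig(vocab_size=30522),
#       num_labels=2,
#       max_seq_length=128)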