# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT models that are compatible with TF 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow_hub as hub

from official.modeling import tf_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs
from official.nlp.modeling import losses
from official.nlp.modeling import models
from official.nlp.modeling import networks


class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
  """Returns layer that computes custom loss and metrics for pretraining."""

  def __init__(self, vocab_size, **kwargs):
    super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
    self._vocab_size = vocab_size
    self.config = {
        'vocab_size': vocab_size,
    }

  def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
                   lm_example_loss, sentence_output, sentence_labels,
                   next_sentence_loss):
    """Adds metrics."""
    masked_lm_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        lm_labels, lm_output)
    numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
    denominator = tf.reduce_sum(lm_label_weights) + 1e-5
    masked_lm_accuracy = numerator / denominator
    self.add_metric(
        masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')

    self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')

    next_sentence_accuracy = tf.keras.metrics.sparse_categorical_accuracy(
        sentence_labels, sentence_output)
    self.add_metric(
        next_sentence_accuracy,
        name='next_sentence_accuracy',
        aggregation='mean')

    self.add_metric(
        next_sentence_loss, name='next_sentence_loss', aggregation='mean')

  def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
           sentence_labels):
    """Implements call() for the layer."""
    lm_label_weights = tf.cast(lm_label_weights, tf.float32)
    lm_output = tf.cast(lm_output, tf.float32)
    sentence_output = tf.cast(sentence_output, tf.float32)

    mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
    sentence_loss = losses.weighted_sparse_categorical_crossentropy_loss(
        labels=sentence_labels, predictions=sentence_output)
    loss = mask_label_loss + sentence_loss
    batch_shape = tf.slice(tf.shape(sentence_labels), [0], [1])
    # TODO(hongkuny): Avoid this hack and switch to add_loss.
    final_loss = tf.fill(batch_shape, loss)

    self._add_metrics(lm_output, lm_label_ids, lm_label_weights,
                      mask_label_loss, sentence_output, sentence_labels,
                      sentence_loss)
    return final_loss
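
# A minimal usage sketch (names such as `lm_output` and `masked_lm_ids` are
# placeholders for the tensors produced in `pretrain_model` below): the layer
# returns a per-example loss tensor and registers its metrics on the
# enclosing model.
#
#   loss_layer = BertPretrainLossAndMetricLayer(vocab_size=30522)
#   output_loss = loss_layer(lm_output, sentence_output, masked_lm_ids,
#                            masked_lm_weights, next_sentence_labels)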


def get_transformer_encoder(bert_config, sequence_length):
  """Gets a 'TransformerEncoder' object.

  Args:
    bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
    sequence_length: Maximum sequence length of the training data.

  Returns:
    A networks.TransformerEncoder object.
  """
  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=sequence_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range))
  if isinstance(bert_config, albert_configs.AlbertConfig):
    kwargs['embedding_width'] = bert_config.embedding_size
    return networks.AlbertTransformerEncoder(**kwargs)
  else:
    assert isinstance(bert_config, configs.BertConfig)
    return networks.TransformerEncoder(**kwargs)
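
# A minimal usage sketch (the config values below are illustrative, not
# defaults taken from this file):
#
#   bert_config = configs.BertConfig(
#       vocab_size=30522, hidden_size=768, num_hidden_layers=12,
#       num_attention_heads=12, intermediate_size=3072)
#   encoder = get_transformer_encoder(bert_config, sequence_length=128)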


def pretrain_model(bert_config,
                   seq_length,
                   max_predictions_per_seq,
                   initializer=None):
  """Returns model to be used for pre-training.

  Args:
      bert_config: Configuration that defines the core BERT model.
      seq_length: Maximum sequence length of the training data.
      max_predictions_per_seq: Maximum number of tokens in sequence to mask out
        and use for pretraining.
      initializer: Initializer for weights in BertPretrainer.

  Returns:
      Pretraining model as well as core BERT submodel from which to save
      weights after pretraining.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(seq_length,), name='input_mask', dtype=tf.int32)
  input_type_ids = tf.keras.layers.Input(
      shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
  masked_lm_positions = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_positions',
      dtype=tf.int32)
  masked_lm_ids = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
  masked_lm_weights = tf.keras.layers.Input(
      shape=(max_predictions_per_seq,),
      name='masked_lm_weights',
      dtype=tf.int32)
  next_sentence_labels = tf.keras.layers.Input(
      shape=(1,), name='next_sentence_labels', dtype=tf.int32)

  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  pretrainer_model = models.BertPretrainer(
      network=transformer_encoder,
      num_classes=2,  # The next sentence prediction label has two classes.
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='predictions')

  lm_output, sentence_output = pretrainer_model(
      [input_word_ids, input_mask, input_type_ids, masked_lm_positions])

  pretrain_loss_layer = BertPretrainLossAndMetricLayer(
      vocab_size=bert_config.vocab_size)
  output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
                                    masked_lm_weights, next_sentence_labels)
  keras_model = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
          'masked_lm_positions': masked_lm_positions,
          'masked_lm_ids': masked_lm_ids,
          'masked_lm_weights': masked_lm_weights,
          'next_sentence_labels': next_sentence_labels,
      },
      outputs=output_loss)
  return keras_model, transformer_encoder
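
# A minimal usage sketch (the optimizer and hyperparameters are illustrative):
# the returned Keras model emits the combined pretraining loss as its output,
# so it can be compiled with a loss function that simply averages that output.
#
#   pretrain_net, encoder = pretrain_model(
#       bert_config, seq_length=128, max_predictions_per_seq=20)
#   pretrain_net.compile(
#       optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
#       loss=lambda _, model_loss: tf.reduce_mean(model_loss))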


def squad_model(bert_config,
                max_seq_length,
                initializer=None,
                hub_module_url=None,
                hub_module_trainable=True):
  """Returns BERT Squad model along with core BERT model to import weights.

  Args:
    bert_config: BertConfig, the config that defines the core BERT model.
    max_seq_length: integer, the maximum input sequence length.
    initializer: Initializer for the final dense layer in the span labeler.
      Defaults to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to the BERT module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    A tuple of (1) keras model that outputs start logits and end logits and
    (2) the core BERT transformer encoder.
  """
  if initializer is None:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)
  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return models.BertSpanLabeler(
        network=bert_encoder, initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
  pooled_output, sequence_output = core_model(
      [input_word_ids, input_mask, input_type_ids])
  bert_encoder = tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids,
      },
      outputs=[sequence_output, pooled_output],
      name='core_model')
  return models.BertSpanLabeler(
      network=bert_encoder, initializer=initializer), bert_encoder
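
# A minimal usage sketch (the sequence length is illustrative, and the three
# input tensors are assumed to be defined elsewhere): the returned span
# labeler maps the three int32 inputs to start/end logits.
#
#   span_labeler, encoder = squad_model(bert_config, max_seq_length=384)
#   start_logits, end_logits = span_labeler(
#       [input_word_ids, input_mask, input_type_ids])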


def classifier_model(bert_config,
                     num_labels,
                     max_seq_length,
                     final_layer_initializer=None,
                     hub_module_url=None,
                     hub_module_trainable=True):
  """BERT classifier model in functional API style.

  Construct a Keras model for predicting `num_labels` outputs from an input with
  maximum sequence length `max_seq_length`.

  Args:
    bert_config: BertConfig or AlbertConfig, the config that defines the core
      BERT or ALBERT model.
    num_labels: integer, the number of classes.
    max_seq_length: integer, the maximum input sequence length.
    final_layer_initializer: Initializer for the final dense layer. Defaults
      to a TruncatedNormal initializer.
    hub_module_url: TF-Hub path/url to the BERT module.
    hub_module_trainable: True to finetune layers in the hub module.

  Returns:
    Combined prediction model (words, mask, type) -> (one-hot labels)
    BERT sub-model (words, mask, type) -> (bert_outputs)
  """
  if final_layer_initializer is not None:
    initializer = final_layer_initializer
  else:
    initializer = tf.keras.initializers.TruncatedNormal(
        stddev=bert_config.initializer_range)

  if not hub_module_url:
    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
    return models.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
        initializer=initializer), bert_encoder

  input_word_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
  bert_model = hub.KerasLayer(
      hub_module_url, trainable=hub_module_trainable)
  pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
  output = tf.keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
      pooled_output)

  output = tf.keras.layers.Dense(
      num_labels, kernel_initializer=initializer, name='output')(
          output)
  return tf.keras.Model(
      inputs={
          'input_word_ids': input_word_ids,
          'input_mask': input_mask,
          'input_type_ids': input_type_ids
      },
      outputs=output), bert_model
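
# A minimal usage sketch (`init_checkpoint` is a placeholder path): when no
# `hub_module_url` is given, pretrained weights can be restored into the
# returned encoder before fine-tuning the classifier.
#
#   classifier, encoder = classifier_model(
#       bert_config, num_labels=2, max_seq_length=128)
#   checkpoint = tf.train.Checkpoint(model=encoder)
#   checkpoint.restore(init_checkpoint).expect_partial()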