# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 BERT model. """


import logging

import numpy as np
import tensorflow as tf

from .configuration_bert import BertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list


logger = logging.getLogger(__name__)


TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tf_model.h5",
    "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-tf_model.h5",
    "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-tf_model.h5",
    "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-tf_model.h5",
    "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-tf_model.h5",
    "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-tf_model.h5",
    "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-tf_model.h5",
    "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-tf_model.h5",
    "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-tf_model.h5",
    "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-tf_model.h5",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5",
    "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-tf_model.h5",
    "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-tf_model.h5",
    "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-tf_model.h5",
    "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-tf_model.h5",
    "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-tf_model.h5",
    "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-tf_model.h5",
    "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/tf_model.h5",
    "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/tf_model.h5",
}


def gelu(x):
    """ Gaussian Error Linear Unit.
    Original Implementation of the gelu activation function in Google Bert repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
    """
    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
    return x * cdf


def gelu_new(x):
    """Gaussian Error Linear Unit.
    This is a smoother version of the RELU.
    Original paper: https://arxiv.org/abs/1606.08415
    Args:
        x: float Tensor to perform activation.
    Returns:
        `x` with the GELU activation applied.
    """
    cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf


def swish(x):
    return x * tf.sigmoid(x)


ACT2FN = {
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.layers.Activation(swish),
    "gelu_new": tf.keras.layers.Activation(gelu_new),
}
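# For example, a BertConfig with hidden_act="gelu" (the default) is resolved through
# ACT2FN["gelu"] by TFBertIntermediate and TFBertPredictionHeadTransform below.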


class TFBertEmbeddings(tf.keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.initializer_range = config.initializer_range

        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="position_embeddings",
        )
        self.token_type_embeddings = tf.keras.layers.Embedding(
            config.type_vocab_size,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="token_type_embeddings",
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def build(self, input_shape):
        """Build shared word embedding layer """
        with tf.name_scope("word_embeddings"):
            # Create and initialize weights. The random normal initializer was chosen
            # arbitrarily, and works well.
            self.word_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(self, inputs, mode="embedding", training=False):
        """Get token embeddings of inputs.
        Args:
            inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(inputs, training=training)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, inputs, training=False):
        """Applies embedding based on inputs tensor."""
        input_ids, position_ids, token_type_ids, inputs_embeds = inputs

        if input_ids is not None:
            input_shape = shape_list(input_ids)
        else:
            input_shape = shape_list(inputs_embeds)[:-1]

        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        if inputs_embeds is None:
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)
        return embeddings

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
            Args:
                inputs: A float32 tensor with shape [batch_size, length, hidden_size]
            Returns:
                float32 tensor with shape [batch_size, length, vocab_size].
        """
        batch_size = shape_list(inputs)[0]
        length = shape_list(inputs)[1]

        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, [batch_size, length, self.vocab_size])


class TFBertSelfAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        assert config.hidden_size % config.num_attention_heads == 0
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )

        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs, training=False):
        hidden_states, attention_mask, head_mask = inputs

        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = tf.matmul(
            query_layer, key_layer, transpose_b=True
        )  # (batch size, num_heads, seq_len_q, seq_len_k)
        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)  # scale attention_scores
        attention_scores = attention_scores / tf.math.sqrt(dk)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in TFBertModel call() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = tf.matmul(attention_probs, value_layer)

        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(
            context_layer, (batch_size, -1, self.all_head_size)
        )  # (batch_size, seq_len_q, all_head_size)

        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs


class TFBertSelfOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        hidden_states, input_tensor = inputs

        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class TFBertAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = TFBertSelfAttention(config, name="self")
        self.dense_output = TFBertSelfOutput(config, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(self, inputs, training=False):
        input_tensor, attention_mask, head_mask = inputs

        self_outputs = self.self_attention([input_tensor, attention_mask, head_mask], training=training)
        attention_output = self.dense_output([self_outputs[0], input_tensor], training=training)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class TFBertIntermediate(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class TFBertOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        hidden_states, input_tensor = inputs

        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class TFBertLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.attention = TFBertAttention(config, name="attention")
        self.intermediate = TFBertIntermediate(config, name="intermediate")
        self.bert_output = TFBertOutput(config, name="output")

    def call(self, inputs, training=False):
        hidden_states, attention_mask, head_mask = inputs

        attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training)
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.bert_output([intermediate_output, attention_output], training=training)
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
        return outputs


class TFBertEncoder(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]

    def call(self, inputs, training=False):
        hidden_states, attention_mask, head_mask = inputs

        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module([hidden_states, attention_mask, head_mask[i]], training=training)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # outputs, (hidden states), (attentions)


class TFBertPooler(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )

    def call(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        return pooled_output


class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")

    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class TFBertLMPredictionHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.transform = TFBertPredictionHeadTransform(config, name="transform")

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias
        return hidden_states


class TFBertMLMHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")

    def call(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class TFBertNSPHead(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.seq_relationship = tf.keras.layers.Dense(
            2, kernel_initializer=get_initializer(config.initializer_range), name="seq_relationship"
        )

    def call(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class TFBertMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.num_hidden_layers = config.num_hidden_layers

        self.embeddings = TFBertEmbeddings(config, name="embeddings")
        self.encoder = TFBertEncoder(config, name="encoder")
        self.pooler = TFBertPooler(config, name="pooler")

    def get_input_embeddings(self):
        return self.embeddings

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        training=False,
    ):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            assert len(inputs) <= 6, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            assert len(inputs) <= 6, "Too many inputs."
        else:
            input_ids = inputs

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.

        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
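        # For example, a padding mask row [1, 1, 0] becomes additive biases
        # [0.0, 0.0, -10000.0] here; added to the raw attention scores before the
        # softmax, the padded position is effectively ignored.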

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
        encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


class TFBertPreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "bert"


BERT_START_DOCSTRING = r"""
    This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class. 
    Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matter related to general usage and behavior.

    .. note::
    
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: :obj:`model(inputs_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

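# Illustrative sketch of the input formats described in BERT_START_DOCSTRING above,
# assuming a built model instance `model` (e.g. TFBertModel) and tensors `input_ids`
# and `attention_mask` of shape (batch_size, sequence_length):
#
#     outputs = model(input_ids)                                   # single Tensor
#     outputs = model([input_ids, attention_mask])                 # list, in docstring order
#     outputs = model({"input_ids": input_ids,
#                      "attention_mask": attention_mask})          # dict keyed by input name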
BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. 
            
            Indices can be obtained using :class:`transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
            
            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
            
            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): 
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token

            `What are token type IDs? <../glossary.html#token-type-ids>`__
        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
            
            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, embedding_dim)`, `optional`, defaults to :obj:`None`): 
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.
"""


@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class TFBertModel(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, name="bert")

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
        Returns:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
                Sequence of hidden-states at the output of the last layer of the model.
            pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
                Last layer hidden-state of the first token of the sequence (classification token)
                further processed by a Linear layer and a Tanh activation function. The Linear
                layer weights are trained from the next sentence prediction (classification)
                objective during Bert pretraining. This output is usually *not* a good summary
                of the semantic content of the input, you're often better with averaging or pooling
                the sequence of hidden-states for the whole input sequence.
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import BertTokenizer, TFBertModel

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = TFBertModel.from_pretrained('bert-base-uncased')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
            outputs = model(input_ids)
            last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
        """
        outputs = self.bert(inputs, **kwargs)
        return outputs


@add_start_docstrings(
    """Bert Model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (config) and inputs:
        prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sentence prediction (classification) head (scores of True/False continuation before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForPreTraining

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForPreTraining.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, seq_relationship_scores = outputs[:2]

        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
        seq_relationship_score = self.nsp(pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[
            2:
        ]  # add hidden states and attention if they are here

        return outputs  # prediction_scores, seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:obj:`~transformers.BertConfig`) and inputs:
        prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForMaskedLM

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForMaskedLM.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        prediction_scores = outputs[0]

        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))

        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here

        return outputs  # prediction_scores, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:obj:`~transformers.BertConfig`) and inputs:
        seq_relationship_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sentence prediction (classification) head (scores of True/False continuation before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForNextSentencePrediction

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        seq_relationship_scores = outputs[0]

        """
        outputs = self.bert(inputs, **kwargs)

        pooled_output = outputs[1]
        seq_relationship_score = self.nsp(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForSequenceClassification(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (config) and inputs:
        logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForSequenceClassification

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        logits = outputs[0]

        """
        outputs = self.bert(inputs, **kwargs)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False))
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # logits, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForMultipleChoice(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        training=False,
    ):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:obj:`~transformers.BertConfig`) and inputs:
        classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
            `num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).

            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForMultipleChoice

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')
        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
        input_ids = tf.constant([tokenizer.encode(s) for s in choices])[None, :]  # Batch size 1, 2 choices
        outputs = model(input_ids)
        classification_scores = outputs[0]
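        # Not part of the original example: a minimal sketch (names are illustrative) of
        # picking the highest-scoring choice.
        predicted_choice = tf.math.argmax(classification_scores, axis=-1)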

        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            assert len(inputs) <= 6, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            assert len(inputs) <= 6, "Too many inputs."
        else:
            input_ids = inputs

        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

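        # Flatten the choice dimension: (batch_size, num_choices, seq_length) becomes
        # (batch_size * num_choices, seq_length) so BERT is run once over every choice.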
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None

        flat_inputs = [
            flat_input_ids,
            flat_attention_mask,
            flat_token_type_ids,
            flat_position_ids,
            head_mask,
            inputs_embeds,
        ]

        outputs = self.bert(flat_inputs, training=training)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)
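        # Each flattened (example, choice) row yields one score; reshape back to (batch_size, num_choices).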
        reshaped_logits = tf.reshape(logits, (-1, num_choices))

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # reshaped_logits, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForTokenClassification(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:obj:`~transformers.BertConfig`) and inputs:
        scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForTokenClassification

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForTokenClassification.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        scores = outputs[0]
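        # Not part of the original example: a minimal sketch (names are illustrative) of
        # mapping per-token scores to predicted label ids.
        predicted_label_ids = tf.math.argmax(scores, axis=-1)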

        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output = outputs[0]

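        # Token classification head: dropout on every token's hidden state, then a linear
        # projection to config.num_labels scores per token.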
        sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # scores, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING,
)
class TFBertForQuestionAnswering(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:obj:`~transformers.BertConfig`) and inputs:
        start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForQuestionAnswering

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForQuestionAnswering.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        start_scores, end_scores = outputs[:2]
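        # Not part of the original example: a minimal sketch (names are illustrative) of
        # greedily decoding the most likely answer span (assumes start <= end).
        all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].numpy().tolist())
        answer_start = int(tf.math.argmax(start_scores, axis=-1)[0])
        answer_end = int(tf.math.argmax(end_scores, axis=-1)[0])
        answer_tokens = all_tokens[answer_start : answer_end + 1]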

        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output = outputs[0]

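        # The QA head projects each token's hidden state to two values; split them into
        # start and end logits and squeeze to shape (batch_size, seq_length).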
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        outputs = (start_logits, end_logits,) + outputs[2:]

        return outputs  # start_logits, end_logits, (hidden_states), (attentions)