# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 BERT model. """


from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import tensorflow as tf

from .configuration_bert import BertConfig
from .file_utils import (
    MULTIPLE_CHOICE_DUMMY_INPUTS,
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPooling,
    TFCausalLMOutput,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
    TFNextSentencePredictorOutput,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
from .modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFMaskedLanguageModelingLoss,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding
from .utils import logging


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"

TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bert-base-uncased",
    "bert-large-uncased",
    "bert-base-cased",
    "bert-large-cased",
    "bert-base-multilingual-uncased",
    "bert-base-multilingual-cased",
    "bert-base-chinese",
    "bert-base-german-cased",
    "bert-large-uncased-whole-word-masking",
    "bert-large-cased-whole-word-masking",
    "bert-large-uncased-whole-word-masking-finetuned-squad",
    "bert-large-cased-whole-word-masking-finetuned-squad",
    "bert-base-cased-finetuned-mrpc",
    "cl-tohoku/bert-base-japanese",
    "cl-tohoku/bert-base-japanese-whole-word-masking",
    "cl-tohoku/bert-base-japanese-char",
    "cl-tohoku/bert-base-japanese-char-whole-word-masking",
    "TurkuNLP/bert-base-finnish-cased-v1",
    "TurkuNLP/bert-base-finnish-uncased-v1",
    "wietsedv/bert-base-dutch-cased",
    # See all BERT models at https://huggingface.co/models?filter=bert
]


def gelu(x):
    """Gaussian Error Linear Unit.
    Original Implementation of the gelu activation function in Google Bert repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
    """
    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))

    return x * cdf


def gelu_new(x):
    """Gaussian Error Linear Unit.
    This is a smoother version of the ReLU.
    Original paper: https://arxiv.org/abs/1606.08415
    Args:
        x: float Tensor to perform activation.
    Returns:
        `x` with the GELU activation applied.
    """
    cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))

    return x * cdf


def swish(x):
    return x * tf.sigmoid(x)


ACT2FN = {
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.layers.Activation(swish),
    "gelu_new": tf.keras.layers.Activation(gelu_new),
}


class TFBertEmbeddings(tf.keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.initializer_range = config.initializer_range
        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="position_embeddings",
        )
        self.token_type_embeddings = tf.keras.layers.Embedding(
            config.type_vocab_size,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="token_type_embeddings",
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def build(self, input_shape):
        """Build shared word embedding layer """
        with tf.name_scope("word_embeddings"):
            # Create and initialize weights. The random normal initializer was chosen
            # arbitrarily, and works well.
            self.word_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        inputs_embeds=None,
        mode="embedding",
        training=False,
    ):
        """Get token embeddings of inputs.
        Args:
            input_ids, position_ids, token_type_ids: int64 tensors with shape [batch_size, length].
            inputs_embeds: optional float32 tensor with shape [batch_size, length, embedding_size], used instead of input_ids.
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
        elif mode == "linear":
            return self._linear(input_ids)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
        """Applies embedding based on inputs tensor."""
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            input_shape = shape_list(input_ids)
        else:
            input_shape = shape_list(inputs_embeds)[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]

        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        if inputs_embeds is None:
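            # No precomputed embeddings were passed, so look the tokens up in the shared word embedding matrix created in build().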
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)

        position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
        token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)

        return embeddings

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
        Args:
            inputs: A float32 tensor with shape [batch_size, length, hidden_size]
        Returns:
            float32 tensor with shape [batch_size, length, vocab_size].
        """
        batch_size = shape_list(inputs)[0]
        length = shape_list(inputs)[1]
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, [batch_size, length, self.vocab_size])


class TFBertSelfAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        assert config.hidden_size % config.num_attention_heads == 0
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x, batch_size):
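        # Reshape [batch_size, seq_len, all_head_size] into [batch_size, num_heads, seq_len, head_size] so attention is computed per head.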
        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))

        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = tf.matmul(
            query_layer, key_layer, transpose_b=True
        )  # (batch size, num_heads, seq_len_q, seq_len_k)
        dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype)  # scale attention_scores
        attention_scores = attention_scores / tf.math.sqrt(dk)

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = tf.matmul(attention_probs, value_layer)
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(
            context_layer, (batch_size, -1, self.all_head_size)
        )  # (batch_size, seq_len_q, all_head_size)
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class TFBertSelfOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, input_tensor, training=False):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states


class TFBertAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = TFBertSelfAttention(config, name="self")
        self.dense_output = TFBertSelfOutput(config, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
        self_outputs = self.self_attention(
            input_tensor, attention_mask, head_mask, output_attentions, training=training
        )
        attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them

        return outputs


class TFBertIntermediate(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class TFBertOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, input_tensor, training=False):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states


class TFBertLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.attention = TFBertAttention(config, name="attention")
        self.intermediate = TFBertIntermediate(config, name="intermediate")
        self.bert_output = TFBertOutput(config, name="output")

    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
        attention_outputs = self.attention(
            hidden_states, attention_mask, head_mask, output_attentions, training=training
        )
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.bert_output(intermediate_output, attention_output, training=training)
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them

        return outputs


class TFBertEncoder(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        output_attentions,
        output_hidden_states,
        return_dict,
        training=False,
    ):
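        # Collect hidden states and attention maps across layers only when they were requested.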
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states, attention_mask, head_mask[i], output_attentions, training=training
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class TFBertPooler(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )

    def call(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)

        return pooled_output


class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act

        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")

    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)

        return hidden_states


class TFBertLMPredictionHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.transform = TFBertPredictionHeadTransform(config, name="transform")

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states):
        hidden_states = self.transform(hidden_states)
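        # Project back to the vocabulary with the tied input embedding matrix, then add a per-token output bias.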
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias

        return hidden_states


class TFBertMLMHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")

    def call(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)

        return prediction_scores


class TFBertNSPHead(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
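        # Two-way classifier over the pooled [CLS] representation for the next sentence prediction objective.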
        self.seq_relationship = tf.keras.layers.Dense(
            2, kernel_initializer=get_initializer(config.initializer_range), name="seq_relationship"
        )
    def call(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)

        return seq_relationship_score


@keras_serializable
class TFBertMainLayer(tf.keras.layers.Layer):
    config_class = BertConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.num_hidden_layers = config.num_hidden_layers
        self.initializer_range = config.initializer_range
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.return_dict = config.use_return_dict
        self.embeddings = TFBertEmbeddings(config, name="embeddings")
        self.encoder = TFBertEncoder(config, name="encoder")
        self.pooler = TFBertPooler(config, name="pooler")

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
        self.embeddings.vocab_size = value.shape[0]

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
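        # Inputs may arrive as a single tensor, a tuple/list in the documented order, or a dict/BatchEncoding; unpack them into named arguments first.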
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
            return_dict = inputs[8] if len(inputs) > 8 else return_dict
            assert len(inputs) <= 9, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            assert len(inputs) <= 9, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.

        extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (
                sequence_output,
                pooled_output,
            ) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class TFBertPreTrainedModel(TFPreTrainedModel):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"


@dataclass
class TFBertForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.TFBertForPreTraining`.

    Args:
        prediction_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False
            continuation before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    prediction_logits: tf.Tensor = None
    seq_relationship_logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None


BERT_START_DOCSTRING = r"""
    This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
    Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    .. note::

        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token

            `What are token type IDs? <../glossary.html#token-type-ids>`__
        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, embedding_dim)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attention tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
            plain tuple.
"""


@add_start_docstrings(
    "The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class TFBertModel(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, name="bert")

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(self, inputs, **kwargs):
        outputs = self.bert(inputs, **kwargs)
        return outputs


@add_start_docstrings(
    """Bert Model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def call(self, inputs, **kwargs):
        r"""
        Return:

        Examples::

            import tensorflow as tf
            from transformers import BertTokenizer, TFBertForPreTraining

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = TFBertForPreTraining.from_pretrained('bert-base-uncased')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
            outputs = model(input_ids)
            prediction_scores, seq_relationship_scores = outputs[:2]

        """
        return_dict = kwargs.get("return_dict")
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        outputs = self.bert(inputs, **kwargs)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
        seq_relationship_score = self.nsp(pooled_output)

        if not return_dict:
            return (prediction_scores, seq_relationship_score) + outputs[2:]

        return TFBertForPreTrainingOutput(
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        if config.is_decoder:
            logger.warning(
                "If you want to use `TFBertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.bert = TFBertMainLayer(config, name="bert")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
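        # labels may be passed as the tenth positional element or as a dict entry; pull it out before the remaining inputs are forwarded to the main layer.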
        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        outputs = self.bert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output, training=training)

        loss = None if labels is None else self.compute_loss(labels, prediction_scores)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        if not config.is_decoder:
            logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`")

        self.bert = TFBertMainLayer(config, name="bert")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFCausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the cross entropy classification loss.
            Indices should be in ``[0, ..., config.vocab_size - 1]``.
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        outputs = self.bert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]
        logits = self.mlm(sequence_output, training=training)

        loss = None
        if labels is not None:
            # shift labels to the left and cut last logit token
            logits = logits[:, :-1]
            labels = labels[:, 1:]
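            # after the shift, logits[:, i] is scored against the token that originally followed position i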
            loss = self.compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFCausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
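
# Minimal usage sketch for TFBertLMHeadModel (illustrative comments only). BERT is
# bidirectional by design, so using it as a standalone left-to-right LM requires a config
# with `is_decoder=True`; the "bert-base-uncased" checkpoint below is just a placeholder.
#
#     from transformers import BertConfig, BertTokenizer, TFBertLMHeadModel
#
#     config = BertConfig.from_pretrained("bert-base-uncased", is_decoder=True)
#     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#     model = TFBertLMHeadModel.from_pretrained("bert-base-uncased", config=config)
#
#     inputs = tokenizer("Paris is the capital of France.", return_tensors="tf")
#     # passing the input ids as labels makes the head score each next-token prediction
#     outputs = model(dict(inputs), labels=inputs["input_ids"], return_dict=True)
#     loss, logits = outputs.loss, outputs.logits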


@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top. """,
    BERT_START_DOCSTRING,
)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def call(self, inputs, **kwargs):
        r"""
        Return:

        Examples::

            import tensorflow as tf
            from transformers import BertTokenizer, TFBertForNextSentencePrediction

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')

            prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
            next_sentence = "The sky is blue due to the shorter wavelength of blue light."
            encoding = tokenizer(prompt, next_sentence, return_tensors='tf')

            logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
            assert logits[0][0] < logits[0][1] # the next sentence was random
        """
        return_dict = kwargs.get("return_dict")
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        outputs = self.bert(inputs, **kwargs)

        pooled_output = outputs[1]
        seq_relationship_score = self.nsp(pooled_output)

        if not return_dict:
            return (seq_relationship_score,) + outputs[2:]

        return TFNextSentencePredictorOutput(
            logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        outputs = self.bert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)

        loss = None if labels is None else self.compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
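
# Minimal fine-tuning sketch for TFBertForSequenceClassification with the plain Keras API
# (comments only; the two-sentence toy dataset and the learning rate below are placeholders,
# not recommendations).
#
#     import tensorflow as tf
#     from transformers import BertTokenizer, TFBertForSequenceClassification
#
#     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#     model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
#
#     texts = ["a delightful film", "a tedious, overlong mess"]
#     labels = tf.constant([1, 0])
#     encodings = tokenizer(texts, padding=True, truncation=True, return_tensors="tf")
#
#     # the model outputs logits, so use a from_logits loss
#     model.compile(
#         optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
#         loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
#     )
#     model.fit(dict(encodings), labels, epochs=1, batch_size=2)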


@add_start_docstrings(
    """Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @property
    def dummy_inputs(self):
        """Dummy inputs to build the network.

        Returns:
            Dict of tf.Tensor with dummy inputs, keyed by input name
        """
        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
            return_dict = inputs[8] if len(inputs) > 8 else return_dict
            labels = inputs[9] if len(inputs) > 9 else labels
            assert len(inputs) <= 10, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            labels = inputs.get("labels", labels)
            assert len(inputs) <= 10, "Too many inputs."
        else:
            input_ids = inputs
        return_dict = return_dict if return_dict is not None else self.bert.return_dict

        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
        flat_inputs_embeds = (
            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )
        outputs = self.bert(
            flat_input_ids,
            flat_attention_mask,
            flat_token_type_ids,
            flat_position_ids,
            head_mask,
            flat_inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)
        reshaped_logits = tf.reshape(logits, (-1, num_choices))

        loss = None if labels is None else self.compute_loss(labels, reshaped_logits)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
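
# Minimal usage sketch for TFBertForMultipleChoice (comments only). Inputs are shaped
# (batch_size, num_choices, sequence_length): every choice is encoded against the same
# prompt and the model scores one logit per choice. The prompt and choices below are
# made-up placeholders.
#
#     import tensorflow as tf
#     from transformers import BertTokenizer, TFBertForMultipleChoice
#
#     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#     model = TFBertForMultipleChoice.from_pretrained("bert-base-uncased")
#
#     prompt = "The glass fell off the table, so"
#     choices = ["it shattered on the floor.", "it compiled without warnings."]
#     encoding = tokenizer([prompt] * len(choices), choices, padding=True, return_tensors="tf")
#
#     # add the batch dimension: (1, num_choices, sequence_length)
#     inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}
#     outputs = model(inputs, return_dict=True)
#     best_choice = int(tf.argmax(outputs.logits, axis=-1)[0])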


@add_start_docstrings(
    """Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        outputs = self.bert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(sequence_output)

        loss = None if labels is None else self.compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
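
# Minimal usage sketch for TFBertForTokenClassification (comments only). The model emits one
# logit vector per token, so predictions are read off with an argmax over the last axis. Note
# that the classification head of a base checkpoint is freshly initialised, so the predictions
# are only meaningful after fine-tuning or when loading an already fine-tuned NER checkpoint;
# "bert-base-cased" below is just a placeholder.
#
#     import tensorflow as tf
#     from transformers import BertTokenizer, TFBertForTokenClassification
#
#     tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
#     model = TFBertForTokenClassification.from_pretrained("bert-base-cased")
#
#     inputs = tokenizer("HuggingFace is based in New York City", return_tensors="tf")
#     outputs = model(dict(inputs), return_dict=True)
#
#     # one predicted label id per (sub)token, including [CLS] and [SEP]
#     predictions = tf.argmax(outputs.logits, axis=-1)
#     tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy().tolist())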


@add_start_docstrings(
    """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING,
)
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        start_positions=None,
        end_positions=None,
        training=False,
    ):
        r"""
        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        if isinstance(inputs, (tuple, list)):
            start_positions = inputs[9] if len(inputs) > 9 else start_positions
            end_positions = inputs[10] if len(inputs) > 10 else end_positions
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            start_positions = inputs.pop("start_positions", start_positions)
            end_positions = inputs.pop("end_positions", end_positions)

        outputs = self.bert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        loss = None
        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.compute_loss(labels, (start_logits, end_logits))

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
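
# Minimal usage sketch for TFBertForQuestionAnswering (comments only). The start/end logits
# are turned into an answer span by taking an argmax over each and decoding the tokens in
# between; the checkpoint below is a public SQuAD-fine-tuned BERT, and the question/context
# pair is a made-up placeholder.
#
#     import tensorflow as tf
#     from transformers import BertTokenizer, TFBertForQuestionAnswering
#
#     name = "bert-large-uncased-whole-word-masking-finetuned-squad"
#     tokenizer = BertTokenizer.from_pretrained(name)
#     model = TFBertForQuestionAnswering.from_pretrained(name)
#
#     question = "Where is the Eiffel Tower?"
#     context = "The Eiffel Tower is a wrought-iron lattice tower located in Paris, France."
#     encoding = tokenizer(question, context, return_tensors="tf")
#     outputs = model(dict(encoding), return_dict=True)
#
#     start = int(tf.argmax(outputs.start_logits, axis=-1)[0])
#     end = int(tf.argmax(outputs.end_logits, axis=-1)[0])
#     answer_ids = encoding["input_ids"][0, start : end + 1].numpy().tolist()
#     print(tokenizer.decode(answer_ids))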