# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 BERT model. """


from dataclasses import dataclass
from typing import Optional, Tuple

import tensorflow as tf

from .activations_tf import get_tf_activation
from .configuration_bert import BertConfig
from .file_utils import (
    MULTIPLE_CHOICE_DUMMY_INPUTS,
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPooling,
    TFCausalLMOutput,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
    TFNextSentencePredictorOutput,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
from .modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFMaskedLanguageModelingLoss,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding
from .utils import logging


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"

TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bert-base-uncased",
    "bert-large-uncased",
    "bert-base-cased",
    "bert-large-cased",
    "bert-base-multilingual-uncased",
    "bert-base-multilingual-cased",
    "bert-base-chinese",
    "bert-base-german-cased",
    "bert-large-uncased-whole-word-masking",
    "bert-large-cased-whole-word-masking",
    "bert-large-uncased-whole-word-masking-finetuned-squad",
    "bert-large-cased-whole-word-masking-finetuned-squad",
    "bert-base-cased-finetuned-mrpc",
    "cl-tohoku/bert-base-japanese",
    "cl-tohoku/bert-base-japanese-whole-word-masking",
    "cl-tohoku/bert-base-japanese-char",
    "cl-tohoku/bert-base-japanese-char-whole-word-masking",
    "TurkuNLP/bert-base-finnish-cased-v1",
    "TurkuNLP/bert-base-finnish-uncased-v1",
    "wietsedv/bert-base-dutch-cased",
    # See all BERT models at https://huggingface.co/models?filter=bert
]


class TFBertEmbeddings(tf.keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.initializer_range = config.initializer_range
        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="position_embeddings",
        )
        self.token_type_embeddings = tf.keras.layers.Embedding(
            config.type_vocab_size,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="token_type_embeddings",
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def build(self, input_shape):
        """Build the shared word embedding layer."""
        with tf.name_scope("word_embeddings"):
            # Create and initialize weights. The random normal initializer was chosen
            # arbitrarily, and works well.
            self.word_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )

        super().build(input_shape)

    def call(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        inputs_embeds=None,
        mode="embedding",
        training=False,
    ):
        """Get token embeddings of inputs.
        Args:
            input_ids: int64 tensor with shape [batch_size, length].
            position_ids: int64 tensor with shape [batch_size, length].
            token_type_ids: int64 tensor with shape [batch_size, length].
            inputs_embeds: float32 tensor with shape [batch_size, length, embedding_size].
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
        elif mode == "linear":
            return self._linear(input_ids)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
        """Applies embedding based on inputs tensor."""
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            input_shape = shape_list(input_ids)
        else:
            input_shape = shape_list(inputs_embeds)[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]

        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        if inputs_embeds is None:
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)

        position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
        token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
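        # BERT input representation: the token embeddings are summed with learned absolute
        # position embeddings and token-type (segment) embeddings, then normalized and dropped out below.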
        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)

        return embeddings

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.

        Args:
            inputs: A float32 tensor with shape [batch_size, length, hidden_size]
        Returns:
            float32 tensor with shape [batch_size, length, vocab_size].
        """
        batch_size = shape_list(inputs)[0]
        length = shape_list(inputs)[1]
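        # Weight tying: multiplying by the transposed word embedding matrix maps hidden states
        # back to vocabulary logits, so this output projection shares its parameters with the
        # input embedding lookup used in "embedding" mode.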
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, [batch_size, length, self.vocab_size])


class TFBertSelfAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        assert config.hidden_size % config.num_attention_heads == 0
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))

        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = tf.matmul(
            query_layer, key_layer, transpose_b=True
        )  # (batch size, num_heads, seq_len_q, seq_len_k)
        dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype)  # scale attention_scores
        attention_scores = attention_scores / tf.math.sqrt(dk)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the TFBertModel call() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = tf.matmul(attention_probs, value_layer)
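        # Together with the scaling and masking above, this weighted sum completes scaled
        # dot-product attention: context = softmax(Q·K^T / sqrt(d_k) + mask) · V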
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(
            context_layer, (batch_size, -1, self.all_head_size)
        )  # (batch_size, seq_len_q, all_head_size)
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class TFBertSelfOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, input_tensor, training=False):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states


class TFBertAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.self_attention = TFBertSelfAttention(config, name="self")
        self.dense_output = TFBertSelfOutput(config, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
        self_outputs = self.self_attention(
            input_tensor, attention_mask, head_mask, output_attentions, training=training
        )
        attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them

        return outputs


class TFBertIntermediate(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act

    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class TFBertOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, input_tensor, training=False):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states


class TFBertLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFBertAttention(config, name="attention")
        self.intermediate = TFBertIntermediate(config, name="intermediate")
        self.bert_output = TFBertOutput(config, name="output")

    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
        attention_outputs = self.attention(
            hidden_states, attention_mask, head_mask, output_attentions, training=training
        )
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.bert_output(intermediate_output, attention_output, training=training)
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them

        return outputs


class TFBertEncoder(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        output_attentions,
        output_hidden_states,
        return_dict,
        training=False,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states, attention_mask, head_mask[i], output_attentions, training=training
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class TFBertPooler(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )

    def call(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)

        return pooled_output


class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        if isinstance(config.hidden_act, str):
            self.transform_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.transform_act_fn = config.hidden_act

        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")

    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)

        return hidden_states


class TFBertLMPredictionHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = config.vocab_size
        self.transform = TFBertPredictionHeadTransform(config, name="transform")

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")

        super().build(input_shape)

    def call(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias

        return hidden_states


class TFBertMLMHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)

        self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")

    def call(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)

        return prediction_scores


class TFBertNSPHead(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.seq_relationship = tf.keras.layers.Dense(
            2, kernel_initializer=get_initializer(config.initializer_range), name="seq_relationship"
        )

    def call(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)

        return seq_relationship_score


@keras_serializable
class TFBertMainLayer(tf.keras.layers.Layer):
    config_class = BertConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.num_hidden_layers = config.num_hidden_layers
        self.initializer_range = config.initializer_range
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.return_dict = config.use_return_dict
        self.embeddings = TFBertEmbeddings(config, name="embeddings")
        self.encoder = TFBertEncoder(config, name="encoder")
        self.pooler = TFBertPooler(config, name="pooler")

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
        self.embeddings.vocab_size = value.shape[0]

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
            return_dict = inputs[8] if len(inputs) > 8 else return_dict
            assert len(inputs) <= 9, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            assert len(inputs) <= 9, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)

        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
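        # Worked example: a padded position (mask value 0) becomes (1.0 - 0.0) * -10000.0 = -10000.0,
        # while a real token (mask value 1) becomes (1.0 - 1.0) * -10000.0 = 0.0, leaving its score unchanged.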

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (
                sequence_output,
                pooled_output,
            ) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class TFBertPreTrainedModel(TFPreTrainedModel):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"


@dataclass
class TFBertForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.TFBertForPreTraining`.

    Args:
        prediction_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False
            continuation before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    prediction_logits: tf.Tensor = None
    seq_relationship_logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None


BERT_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
    generic methods the library implements for all its models (such as downloading or saving, resizing the input
    embeddings, pruning heads etc.)

    This model is also a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ subclass.
    Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general
    usage and behavior.

    .. note::

        TF 2.0 models accept two formats as inputs:

        - having all inputs as keyword arguments (like PyTorch models), or
        - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the :meth:`tf.keras.Model.fit` method, which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Args:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.__call__` and
            :func:`transformers.PreTrainedTokenizer.encode` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`__
        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`tf.Tensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
        training (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class TFBertModel(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(self, inputs, **kwargs):
        outputs = self.bert(inputs, **kwargs)

        return outputs
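
# Illustrative usage of TFBertModel (a sketch only, not part of the original module; the
# canonical example is the code-sample docstring generated by the decorators above, and
# "bert-base-cased" is one of the checkpoints in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST):
#
#     from transformers import BertTokenizer, TFBertModel
#
#     tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
#     model = TFBertModel.from_pretrained("bert-base-cased")
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
#     sequence_output, pooled_output = model(inputs)[:2]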


@add_start_docstrings(
    """Bert Model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def call(self, inputs, **kwargs):
        r"""
        Return:

        Examples::

            >>> import tensorflow as tf
            >>> from transformers import BertTokenizer, TFBertForPreTraining

            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            >>> model = TFBertForPreTraining.from_pretrained('bert-base-uncased')
            >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
            >>> outputs = model(input_ids)
            >>> prediction_scores, seq_relationship_scores = outputs[:2]

        """
        return_dict = kwargs.get("return_dict")
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        outputs = self.bert(inputs, **kwargs)
        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
        seq_relationship_score = self.nsp(pooled_output)

        if not return_dict:
            return (prediction_scores, seq_relationship_score) + outputs[2:]

        return TFBertForPreTrainingOutput(
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):

    authorized_unexpected_keys = [r"pooler"]
    authorized_missing_keys = [r"pooler"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        if config.is_decoder:
            logger.warning(
                "If you want to use `TFBertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )
        self.bert = TFBertMainLayer(config, name="bert")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for
            the tokens with labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict

        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        outputs = self.bert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output, training=training)
        loss = None if labels is None else self.compute_loss(labels, prediction_scores)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
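
# Illustrative masked-LM usage (a sketch only, mirroring the Examples blocks used elsewhere in
# this file; "bert-base-cased" is one of the checkpoints in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST):
#
#     from transformers import BertTokenizer, TFBertForMaskedLM
#
#     tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
#     model = TFBertForMaskedLM.from_pretrained("bert-base-cased")
#     inputs = tokenizer("The capital of France is [MASK].", return_tensors="tf")
#     prediction_scores = model(inputs)[0]  # (batch_size, sequence_length, vocab_size)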


class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):

    authorized_unexpected_keys = [r"pooler"]
    authorized_missing_keys = [r"pooler"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        if not config.is_decoder:
            logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True`.")

        self.bert = TFBertMainLayer(config, name="bert")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFCausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the cross entropy classification loss.
            Indices should be in ``[0, ..., config.vocab_size - 1]``.
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict

        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        outputs = self.bert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
Sylvain Gugger's avatar
Sylvain Gugger committed
1001
            return_dict=return_dict,
1002
1003
1004
1005
1006
            training=training,
        )

        sequence_output = outputs[0]
        logits = self.mlm(sequence_output, training=training)
Sylvain Gugger's avatar
Sylvain Gugger committed
1007
        loss = None
1008

1009
1010
1011
1012
1013
1014
        if labels is not None:
            # shift labels to the left and cut last logit token
            logits = logits[:, :-1]
            labels = labels[:, 1:]
            loss = self.compute_loss(labels, logits)

Sylvain Gugger's avatar
Sylvain Gugger committed
1015
1016
1017
1018
1019
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFCausalLMOutput(
Lysandre's avatar
Lysandre committed
1020
1021
1022
1023
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
Sylvain Gugger's avatar
Sylvain Gugger committed
1024
        )


@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top. """,
    BERT_START_DOCSTRING,
)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def call(self, inputs, **kwargs):
        r"""
        Return:

        Examples::

            >>> import tensorflow as tf
            >>> from transformers import BertTokenizer, TFBertForNextSentencePrediction

            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            >>> model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')

            >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
            >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
            >>> encoding = tokenizer(prompt, next_sentence, return_tensors='tf')

            >>> logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
            >>> assert logits[0][0] < logits[0][1] # the next sentence was random
        """
        return_dict = kwargs.get("return_dict")
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
        outputs = self.bert(inputs, **kwargs)
        pooled_output = outputs[1]
        seq_relationship_score = self.nsp(pooled_output)

        if not return_dict:
            return (seq_relationship_score,) + outputs[2:]

        return TFNextSentencePredictorOutput(
            logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss);
            if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
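
        Example (a minimal usage sketch; it assumes the public ``bert-base-cased`` checkpoint and
        that ``from_pretrained`` forwards ``num_labels=2`` to the configuration)::

            >>> import tensorflow as tf
            >>> from transformers import BertTokenizer, TFBertForSequenceClassification

            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> model = TFBertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=2)

            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors='tf')
            >>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1))  # a single classification label

            >>> outputs = model(inputs)
            >>> loss, logits = outputs[:2]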
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict

        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        outputs = self.bert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)
        loss = None if labels is None else self.compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @property
    def dummy_inputs(self):
        """Dummy inputs to build the network.

        Returns:
            tf.Tensor with dummy inputs
        """
        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where :obj:`num_choices` is the size of the second
            dimension of the input tensors. (See :obj:`input_ids` above)
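
        Example (a minimal usage sketch showing how the ``num_choices`` dimension is laid out; it
        assumes the public ``bert-base-cased`` checkpoint)::

            >>> import tensorflow as tf
            >>> from transformers import BertTokenizer, TFBertForMultipleChoice

            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> model = TFBertForMultipleChoice.from_pretrained('bert-base-cased')

            >>> prompt = "In Italy, pizza served in formal settings is presented unsliced."
            >>> choice0 = "It is eaten with a fork and a knife."
            >>> choice1 = "It is eaten while held in the hand."

            >>> # encode each (prompt, choice) pair, then add a leading num_choices dimension
            >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors='tf', padding=True)
            >>> inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}

            >>> logits = model(inputs)[0]  # shape (batch_size, num_choices)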
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
            return_dict = inputs[8] if len(inputs) > 8 else return_dict
            labels = inputs[9] if len(inputs) > 9 else labels
            assert len(inputs) <= 10, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            labels = inputs.get("labels", labels)
            assert len(inputs) <= 10, "Too many inputs."
        else:
            input_ids = inputs

        return_dict = return_dict if return_dict is not None else self.bert.return_dict

        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

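        # flatten the choices dimension into the batch so every (prompt, choice) pair goes through BERT in one pass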
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
        flat_inputs_embeds = (
            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )
        outputs = self.bert(
            flat_input_ids,
            flat_attention_mask,
            flat_token_type_ids,
            flat_position_ids,
            head_mask,
            flat_inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)
        reshaped_logits = tf.reshape(logits, (-1, num_choices))
        loss = None if labels is None else self.compute_loss(labels, reshaped_logits)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):

    authorized_unexpected_keys = [r"pooler"]
    authorized_missing_keys = [r"pooler"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
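
        Example (a minimal usage sketch; it assumes the public ``bert-base-cased`` checkpoint)::

            >>> import tensorflow as tf
            >>> from transformers import BertTokenizer, TFBertForTokenClassification

            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> model = TFBertForTokenClassification.from_pretrained('bert-base-cased')

            >>> inputs = tokenizer("HuggingFace is based in NYC", return_tensors='tf')
            >>> logits = model(inputs)[0]  # shape (batch_size, sequence_length, num_labels)
            >>> predicted_label_ids = tf.argmax(logits, axis=-1)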
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict

        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        outputs = self.bert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(sequence_output)
        loss = None if labels is None else self.compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING,
)
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):

    authorized_unexpected_keys = [r"pooler"]
    authorized_missing_keys = [r"pooler"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.bert = TFBertMainLayer(config, name="bert")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-cased",
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        start_positions=None,
        end_positions=None,
        training=False,
    ):
        r"""
        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
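
        Example (a minimal usage sketch; it assumes the public
        ``bert-large-uncased-whole-word-masking-finetuned-squad`` checkpoint)::

            >>> import tensorflow as tf
            >>> from transformers import BertTokenizer, TFBertForQuestionAnswering

            >>> checkpoint = 'bert-large-uncased-whole-word-masking-finetuned-squad'
            >>> tokenizer = BertTokenizer.from_pretrained(checkpoint)
            >>> model = TFBertForQuestionAnswering.from_pretrained(checkpoint)

            >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
            >>> inputs = tokenizer(question, text, return_tensors='tf')
            >>> start_logits, end_logits = model(inputs)[:2]

            >>> # decode the most likely start/end token positions back into text
            >>> all_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].numpy().tolist())
            >>> start = int(tf.argmax(start_logits, axis=1)[0])
            >>> end = int(tf.argmax(end_logits, axis=1)[0])
            >>> answer = ' '.join(all_tokens[start : end + 1])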
        """
        return_dict = return_dict if return_dict is not None else self.bert.return_dict

        if isinstance(inputs, (tuple, list)):
            start_positions = inputs[9] if len(inputs) > 9 else start_positions
            end_positions = inputs[10] if len(inputs) > 10 else end_positions
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            start_positions = inputs.pop("start_positions", start_positions)
            end_positions = inputs.pop("end_positions", end_positions)

        outputs = self.bert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]
        logits = self.qa_outputs(sequence_output)
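        # qa_outputs produces two scores per token; split them into start-position and end-position logits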
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)
        loss = None

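        # pack the span labels into the dict layout expected by TFQuestionAnsweringLoss.compute_loss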
        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.compute_loss(labels, (start_logits, end_logits))

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )