# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 BERT model. """


import logging

import numpy as np
import tensorflow as tf

from .configuration_bert import BertConfig
from .file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding


logger = logging.getLogger(__name__)


TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bert-base-uncased",
    "bert-large-uncased",
    "bert-base-cased",
    "bert-large-cased",
    "bert-base-multilingual-uncased",
    "bert-base-multilingual-cased",
    "bert-base-chinese",
    "bert-base-german-cased",
    "bert-large-uncased-whole-word-masking",
    "bert-large-cased-whole-word-masking",
    "bert-large-uncased-whole-word-masking-finetuned-squad",
    "bert-large-cased-whole-word-masking-finetuned-squad",
    "bert-base-cased-finetuned-mrpc",
    "cl-tohoku/bert-base-japanese",
    "cl-tohoku/bert-base-japanese-whole-word-masking",
    "cl-tohoku/bert-base-japanese-char",
    "cl-tohoku/bert-base-japanese-char-whole-word-masking",
    "TurkuNLP/bert-base-finnish-cased-v1",
    "TurkuNLP/bert-base-finnish-uncased-v1",
    "wietsedv/bert-base-dutch-cased",
    # See all BERT models at https://huggingface.co/models?filter=bert
]


def gelu(x):
    """ Gaussian Error Linear Unit.
    Original implementation of the gelu activation function in the Google BERT repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
    """
    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
    return x * cdf


def gelu_new(x):
    """Gaussian Error Linear Unit.
    This is a smoother version of the ReLU.
    Original paper: https://arxiv.org/abs/1606.08415
    Args:
        x: float Tensor to perform activation.
    Returns:
        `x` with the GELU activation applied.
    """
    cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf


def swish(x):
    return x * tf.sigmoid(x)


ACT2FN = {
    "gelu": tf.keras.layers.Activation(gelu),
    "relu": tf.keras.activations.relu,
    "swish": tf.keras.layers.Activation(swish),
    "gelu_new": tf.keras.layers.Activation(gelu_new),
}
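
# The mapping above lets ``config.hidden_act`` be given either as one of the string keys
# ("gelu", "relu", "swish", "gelu_new") or directly as a callable. A minimal illustrative
# sketch of how a layer resolves it (not executed by this module; the tensor values are
# made up for the example):
#
#     hidden_act = "gelu"  # e.g. the value coming from a BertConfig
#     act_fn = ACT2FN[hidden_act] if isinstance(hidden_act, str) else hidden_act
#     y = act_fn(tf.constant([-1.0, 0.0, 1.0]))  # element-wise activation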


class TFBertEmbeddings(tf.keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.initializer_range = config.initializer_range

        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="position_embeddings",
        )
        self.token_type_embeddings = tf.keras.layers.Embedding(
            config.type_vocab_size,
            config.hidden_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="token_type_embeddings",
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def build(self, input_shape):
        """Build shared word embedding layer """
        with tf.name_scope("word_embeddings"):
            # Create and initialize weights. The random normal initializer was chosen
            # arbitrarily, and works well.
            self.word_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.hidden_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(self, inputs, mode="embedding", training=False):
        """Get token embeddings of inputs.
        Args:
            inputs: list of four tensors (input_ids, position_ids, token_type_ids, inputs_embeds);
                the id tensors are int tensors with shape [batch_size, length] and inputs_embeds,
                if provided, is a float tensor with shape [batch_size, length, embedding_size].
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(inputs, training=training)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, inputs, training=False):
        """Applies embedding based on inputs tensor."""
        input_ids, position_ids, token_type_ids, inputs_embeds = inputs

        if input_ids is not None:
            input_shape = shape_list(input_ids)
        else:
            input_shape = shape_list(inputs_embeds)[:-1]

        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        if inputs_embeds is None:
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)
        return embeddings

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
            Args:
                inputs: A float32 tensor with shape [batch_size, length, hidden_size]
            Returns:
                float32 tensor with shape [batch_size, length, vocab_size].
        """
        batch_size = shape_list(inputs)[0]
        length = shape_list(inputs)[1]

        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, [batch_size, length, self.vocab_size])
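
# Note on the two modes above: ``mode="embedding"`` maps ids of shape [batch_size, length] to
# vectors of shape [batch_size, length, hidden_size], while ``mode="linear"`` reuses the same
# ``word_embeddings`` matrix (transposed) to map hidden states back to vocabulary logits of shape
# [batch_size, length, vocab_size]. Illustrative sketch only, assuming an already built
# ``embeddings`` instance:
#
#     hidden = embeddings([input_ids, None, None, None], mode="embedding")  # ids -> hidden states
#     logits = embeddings(hidden, mode="linear")                            # hidden states -> vocab logits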


class TFBertSelfAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        assert config.hidden_size % config.num_attention_heads == 0
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )

        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs, training=False):
        hidden_states, attention_mask, head_mask = inputs

        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = tf.matmul(
            query_layer, key_layer, transpose_b=True
        )  # (batch size, num_heads, seq_len_q, seq_len_k)
        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)  # scale attention_scores
        attention_scores = attention_scores / tf.math.sqrt(dk)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in TFBertModel call() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = tf.matmul(attention_probs, value_layer)

        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(
            context_layer, (batch_size, -1, self.all_head_size)
        )  # (batch_size, seq_len_q, all_head_size)

        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs
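
# Shape walkthrough for the attention block above (illustrative), with B = batch size,
# S = sequence length, H = num_attention_heads, D = attention_head_size (H * D = all_head_size):
#
#     query/key/value projections:              [B, S, H * D]
#     after transpose_for_scores:               [B, H, S, D]
#     attention_scores = Q @ K^T / sqrt(D):     [B, H, S, S]
#     context_layer = softmax(scores) @ V:      [B, H, S, D], reshaped back to [B, S, H * D]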


class TFBertSelfOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        hidden_states, input_tensor = inputs

        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class TFBertAttention(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = TFBertSelfAttention(config, name="self")
        self.dense_output = TFBertSelfOutput(config, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(self, inputs, training=False):
        input_tensor, attention_mask, head_mask = inputs

        self_outputs = self.self_attention([input_tensor, attention_mask, head_mask], training=training)
        attention_output = self.dense_output([self_outputs[0], input_tensor], training=training)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class TFBertIntermediate(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class TFBertOutput(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, inputs, training=False):
        hidden_states, input_tensor = inputs

        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class TFBertLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.attention = TFBertAttention(config, name="attention")
        self.intermediate = TFBertIntermediate(config, name="intermediate")
        self.bert_output = TFBertOutput(config, name="output")

    def call(self, inputs, training=False):
        hidden_states, attention_mask, head_mask = inputs

        attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training)
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.bert_output([intermediate_output, attention_output], training=training)
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
        return outputs


class TFBertEncoder(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]

    def call(self, inputs, training=False):
        hidden_states, attention_mask, head_mask = inputs

        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module([hidden_states, attention_mask, head_mask[i]], training=training)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # outputs, (hidden states), (attentions)
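
# The tuple returned above grows with the configuration flags (a sketch of the possible cases):
#
#     neither flag set            -> (last_hidden_state,)
#     output_hidden_states=True   -> all_hidden_states (embeddings + every layer) is appended
#     output_attentions=True      -> all_attentions (one tensor per layer) is appended last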


class TFBertPooler(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )

    def call(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        return pooled_output


class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")

    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class TFBertLMPredictionHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.transform = TFBertPredictionHeadTransform(config, name="transform")

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias
        return hidden_states
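
# The prediction head above ties its decoder to the input word embeddings: ``mode="linear"``
# multiplies the transformed hidden states by the transposed embedding matrix, and only the
# per-token ``bias`` is a new trainable variable, so logits of shape
# [batch_size, seq_length, vocab_size] are produced without a separate vocabulary-sized
# projection matrix.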


class TFBertMLMHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")

    def call(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class TFBertNSPHead(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.seq_relationship = tf.keras.layers.Dense(
            2, kernel_initializer=get_initializer(config.initializer_range), name="seq_relationship"
        )

    def call(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


@keras_serializable
class TFBertMainLayer(tf.keras.layers.Layer):
    config_class = BertConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.num_hidden_layers = config.num_hidden_layers

        self.embeddings = TFBertEmbeddings(config, name="embeddings")
        self.encoder = TFBertEncoder(config, name="encoder")
        self.pooler = TFBertPooler(config, name="pooler")

    def get_input_embeddings(self):
        return self.embeddings

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        training=False,
    ):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            assert len(inputs) <= 6, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            assert len(inputs) <= 6, "Too many inputs."
        else:
            input_ids = inputs

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.

        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
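        # Worked example of the conversion above (illustrative values): a padding mask of
        #     [[1, 1, 1, 0]]                    # shape [batch_size, seq_length]
        # becomes, after broadcasting and the (1.0 - mask) * -10000.0 transform,
        #     [[[[0.0, 0.0, 0.0, -10000.0]]]]   # shape [batch_size, 1, 1, seq_length]
        # so padded positions receive a large negative bias before the softmax in self-attention.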

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
        encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


class TFBertPreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"


BERT_START_DOCSTRING = r"""
    This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
    Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matter related to general usage and behavior.

    .. note::

        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional arguments.

        This second option is useful when using :obj:`tf.keras.Model.fit()` method which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token

            `What are token type IDs? <../glossary.html#token-type-ids>`__
        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, embedding_dim)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.
"""


@add_start_docstrings(
    "The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class TFBertModel(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, name="bert")

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
    Returns:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during Bert pretraining. This output is usually *not* a good summary
            of the semantic content of the input; you're often better off averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.


    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertModel

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertModel.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
        """
        outputs = self.bert(inputs, **kwargs)
        return outputs


@add_start_docstrings(
    """Bert Model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForPreTraining

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForPreTraining.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, seq_relationship_scores = outputs[:2]

        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
        seq_relationship_score = self.nsp(pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[
            2:
        ]  # add hidden states and attention if they are here

        return outputs  # prediction_scores, seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForMaskedLM

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForMaskedLM.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        outputs = model(input_ids)
        prediction_scores = outputs[0]

        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))

        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here

        return outputs  # prediction_scores, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def call(self, inputs, **kwargs):
        r"""
    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        seq_relationship_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForNextSentencePrediction

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')

        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        encoding = tokenizer.encode_plus(prompt, next_sentence, return_tensors='tf')

        logits = model(encoding['input_ids'], token_type_ids=encoding['token_type_ids'])[0]
        assert logits[0][0] < logits[0][1] # the next sentence was random
        """
        outputs = self.bert(inputs, **kwargs)

        pooled_output = outputs[1]
        seq_relationship_score = self.nsp(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForSequenceClassification

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        labels = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            training=training,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss = self.compute_loss(labels, logits)
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
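        # The classifier above has a single output unit: each flattened (example, choice) pair gets
        # one score, and the scores are reshaped back to (batch_size, num_choices) in `call`.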

    @property
    def dummy_inputs(self):
        """ Dummy inputs to build the network.

        Returns:
            Dict of str to tf.Tensor with dummy inputs
        """
        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
            `num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).

            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForMultipleChoice

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')
        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]

        input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :] # Batch size 1, 2 choices
        labels = tf.reshape(tf.constant(1), (-1, 1))
        outputs = model(input_ids, labels=labels)

        loss, classification_scores = outputs[:2]

        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            assert len(inputs) <= 6, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            assert len(inputs) <= 6, "Too many inputs."
        else:
            input_ids = inputs

        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]
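        # num_choices is the size of dim 1 and seq_length the size of dim 2; the inputs are then
        # flattened from (batch_size, num_choices, seq_length) to (batch_size * num_choices,
        # seq_length) so that every choice runs through BERT as an ordinary example.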

        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None

        flat_inputs = [
            flat_input_ids,
            flat_attention_mask,
            flat_token_type_ids,
            flat_position_ids,
            head_mask,
            inputs_embeds,
        ]

        outputs = self.bert(flat_inputs, training=training)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)
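        # Un-flatten: one score per choice, back to (batch_size, num_choices).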
        reshaped_logits = tf.reshape(logits, (-1, num_choices))

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss = self.compute_loss(labels, reshaped_logits)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForTokenClassification

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForTokenClassification.from_pretrained('bert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
        labels = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, scores = outputs[:2]

        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            training=training,
        )
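        # Token classification uses the per-token hidden states (outputs[0]) rather than the pooled
        # output: each position gets its own dropout + linear projection to num_labels scores.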

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss = self.compute_loss(labels, logits)
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING,
)
class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        cls_index=None,
        p_mask=None,
        is_impossible=None,
        training=False,
    ):
        r"""
        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.

    Return:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import BertTokenizer, TFBertForQuestionAnswering

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = TFBertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        input_dict = tokenizer.encode_plus(question, text, return_tensors='tf')
        start_scores, end_scores = model(input_dict)

        all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
        answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
        assert answer == "a nice puppet"

        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            training=training,
        )

        sequence_output = outputs[0]
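        # qa_outputs maps every token's hidden state to two scores; splitting along the last axis
        # gives start and end logits, each squeezed to shape (batch_size, sequence_length).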

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        outputs = (start_logits, end_logits,) + outputs[2:]
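        # When both span boundaries are given, package them as the dict expected by
        # TFQuestionAnsweringLoss.compute_loss and prepend the loss to the outputs.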

        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.compute_loss(labels, outputs[:2])
            outputs = (loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)