# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model. """


import logging
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from .activations import gelu, gelu_new, swish
from .configuration_bert import BertConfig
from .file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    CausalLMOutput,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer


logger = logging.getLogger(__name__)

_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"

BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bert-base-uncased",
    "bert-large-uncased",
    "bert-base-cased",
    "bert-large-cased",
    "bert-base-multilingual-uncased",
    "bert-base-multilingual-cased",
    "bert-base-chinese",
    "bert-base-german-cased",
    "bert-large-uncased-whole-word-masking",
    "bert-large-cased-whole-word-masking",
    "bert-large-uncased-whole-word-masking-finetuned-squad",
    "bert-large-cased-whole-word-masking-finetuned-squad",
    "bert-base-cased-finetuned-mrpc",
    "bert-base-german-dbmdz-cased",
    "bert-base-german-dbmdz-uncased",
    "cl-tohoku/bert-base-japanese",
    "cl-tohoku/bert-base-japanese-whole-word-masking",
    "cl-tohoku/bert-base-japanese-char",
    "cl-tohoku/bert-base-japanese-char-whole-word-masking",
    "TurkuNLP/bert-base-finnish-cased-v1",
    "TurkuNLP/bert-base-finnish-uncased-v1",
    "wietsedv/bert-base-dutch-cased",
    # See all BERT models at https://huggingface.co/models?filter=bert
]


def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model.
88
    """
89
90
91
92
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using a pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model


def mish(x):
    return x * torch.tanh(nn.functional.softplus(x))


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}


BertLayerNorm = torch.nn.LayerNorm


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
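        # query/key/value layers are now shaped (batch_size, num_attention_heads, seq_length, attention_head_size)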

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
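        # context_layer is now back to (batch_size, seq_length, all_head_size)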

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        if self.is_decoder:
            self.crossattention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        self_attention_outputs = self.attention(
            hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + outputs
        return outputs


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_tuple=False,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if getattr(self.config, "gradient_checkpointing", False):
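                # output_attentions is captured in a closure here because torch.utils.checkpoint.checkpoint
                # only forwards positional (tensor) arguments to the wrapped function.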

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_tuple:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    authorized_missing_keys = [r"position_ids"]

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


@dataclass
class BertForPretrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.BertForPreTraining`.

    Args:
        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False
            continuation before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor]
    prediction_logits: torch.FloatTensor
    seq_relationship_logits: torch.FloatTensor
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


BERT_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""


@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class BertModel(BertPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well
    as a decoder, in which case a layer of cross-attention is added between
    the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the
    :obj:`is_decoder` argument of the configuration set to :obj:`True`; an
    :obj:`encoder_hidden_states` is expected as an input to the forward pass.

    .. _`Attention is all you need`:
        https://arxiv.org/abs/1706.03762

    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
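        # get_extended_attention_mask broadcasts the mask to (batch_size, 1, 1, seq_length) and converts it to
        # additive form (0.0 for positions to keep, -10000.0 for masked positions) so it can be added to the scores.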

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if return_tuple:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
    a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=BertForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
        **kwargs
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

    Returns:

    Examples::

        >>> from transformers import BertTokenizer, BertForPreTraining
        >>> import torch

        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        """
        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if return_tuple:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return BertForPretrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."

        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
        **kwargs
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the left-to-right language modeling loss (next word prediction).
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

    Returns:

    Example::

        >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
        >>> import torch

        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        >>> config = BertConfig.from_pretrained("bert-base-cased")
        >>> config.is_decoder = True
        >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        """
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if return_tuple:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutput(
            loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        return {"input_ids": input_ids, "attention_mask": attention_mask}


@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        assert (
            not config.is_decoder
        ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."

        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
        **kwargs
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see the ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``.
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.
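
    Example (a minimal sketch, assuming the ``bert-base-uncased`` checkpoint can be downloaded)::

        >>> from transformers import BertTokenizer, BertForMaskedLM
        >>> import torch

        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> model = BertForMaskedLM.from_pretrained('bert-base-uncased')

        >>> # mask a single token and use the full sentence as labels; positions that should not contribute
        >>> # to the loss could instead be set to -100
        >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]

        >>> outputs = model(**inputs, labels=labels)
        >>> masked_lm_loss = outputs.loss
        >>> prediction_scores = outputs.logits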
        """
        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if return_tuple:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        #  add a dummy PAD token at the end of the input (masked out in the attention mask below)
        assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}


@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
)
class BertForNextSentencePrediction(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
    ):
        r"""
        next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Returns:

    Example::

        >>> from transformers import BertTokenizer, BertForNextSentencePrediction
        >>> import torch

        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')

        >>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
        """
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        pooled_output = outputs[1]

        seq_relationship_scores = self.cls(pooled_output)

        next_sentence_loss = None
        if next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), next_sentence_label.view(-1))

        if return_tuple:
            output = (seq_relationship_scores,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
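
    Example (a minimal sketch, assuming the ``bert-base-uncased`` checkpoint; the classification head is freshly
    initialized, so the loss is only illustrative)::

        >>> from transformers import BertTokenizer, BertForSequenceClassification
        >>> import torch

        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> labels = torch.tensor([1])  # one label for the single example in the batch
        >>> outputs = model(**inputs, labels=labels)

        >>> loss = outputs.loss
        >>> logits = outputs.logits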
        """
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if return_tuple:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING,
)
class BertForMultipleChoice(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)
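
    Example (a minimal sketch, assuming the ``bert-base-uncased`` checkpoint and a tokenizer that supports batched
    sequence pairs with ``padding=True``)::

        >>> from transformers import BertTokenizer, BertForMultipleChoice
        >>> import torch

        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> model = BertForMultipleChoice.from_pretrained('bert-base-uncased')

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> choice0 = "It is eaten with a fork and a knife."
        >>> choice1 = "It is eaten while held in the hand."

        >>> # encode the prompt once per choice, then add the num_choices dimension: (batch_size=1, num_choices=2, seq_len)
        >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors='pt', padding=True)
        >>> inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}

        >>> outputs = model(**inputs, labels=torch.tensor([0]))  # choice0 is the correct continuation
        >>> loss = outputs.loss
        >>> logits = outputs.logits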
        """
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if return_tuple:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
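
    Example (a minimal sketch, assuming the ``bert-base-uncased`` checkpoint; the dummy labels simply tag every
    token, including special tokens, with class 1)::

        >>> from transformers import BertTokenizer, BertForTokenClassification
        >>> import torch

        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> model = BertForTokenClassification.from_pretrained('bert-base-uncased')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> labels = torch.ones(inputs["input_ids"].shape, dtype=torch.long)  # shape (batch_size, sequence_length)

        >>> outputs = model(**inputs, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits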
        """
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if return_tuple:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING,
)
class BertForQuestionAnswering(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
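
    Example (a minimal sketch, assuming the ``bert-base-uncased`` checkpoint; the start/end positions are dummy
    token indices, only meant to show how the span loss is requested)::

        >>> from transformers import BertTokenizer, BertForQuestionAnswering
        >>> import torch

        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')

        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        >>> inputs = tokenizer(question, text, return_tensors='pt')

        >>> outputs = model(**inputs, start_positions=torch.tensor([1]), end_positions=torch.tensor([3]))
        >>> loss = outputs.loss
        >>> start_logits = outputs.start_logits
        >>> end_logits = outputs.end_logits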
        """
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the split adds a dimension; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if return_tuple:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )