"vscode:/vscode.git/clone" did not exist on "a1a34657d41627b21dddf2bf9cc55941329a60b6"
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model. """


import logging
import math
import os
import warnings

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from .activations import gelu, gelu_new, swish
from .configuration_bert import BertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer


logger = logging.getLogger(__name__)

BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bert-base-uncased",
    "bert-large-uncased",
    "bert-base-cased",
    "bert-large-cased",
    "bert-base-multilingual-uncased",
    "bert-base-multilingual-cased",
    "bert-base-chinese",
    "bert-base-german-cased",
    "bert-large-uncased-whole-word-masking",
    "bert-large-cased-whole-word-masking",
    "bert-large-uncased-whole-word-masking-finetuned-squad",
    "bert-large-cased-whole-word-masking-finetuned-squad",
    "bert-base-cased-finetuned-mrpc",
    "bert-base-german-dbmdz-cased",
    "bert-base-german-dbmdz-uncased",
    "cl-tohoku/bert-base-japanese",
    "cl-tohoku/bert-base-japanese-whole-word-masking",
    "cl-tohoku/bert-base-japanese-char",
    "cl-tohoku/bert-base-japanese-char-whole-word-masking",
    "TurkuNLP/bert-base-finnish-cased-v1",
    "TurkuNLP/bert-base-finnish-uncased-v1",
    "wietsedv/bert-base-dutch-cased",
    # See all BERT models at https://huggingface.co/models?filter=bert
]


def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """ Load TF checkpoints in a PyTorch model.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using a pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
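

# Illustrative sketch (not part of the original file): how load_tf_weights_in_bert is typically
# driven when converting an original TensorFlow checkpoint. The helper name and paths below are
# hypothetical; BertConfig and BertForPreTraining are the classes used in this module.
def _convert_tf_checkpoint_sketch(tf_checkpoint_path, bert_config_file):
    # Build a randomly initialised model from the JSON config, then overwrite its weights
    # with the values stored in the TensorFlow checkpoint.
    config = BertConfig.from_json_file(bert_config_file)
    model = BertForPreTraining(config)
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)
    return model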


def mish(x):
    return x * torch.tanh(nn.functional.softplus(x))


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
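

# Illustrative sketch (hypothetical helper, not part of the original file): ``config.hidden_act``
# may be either a string key into ACT2FN (e.g. "gelu") or a callable; BertIntermediate and
# BertPredictionHeadTransform resolve it exactly like this.
def _resolve_activation_sketch(name_or_fn):
    return ACT2FN[name_or_fn] if isinstance(name_or_fn, str) else name_or_fn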


BertLayerNorm = torch.nn.LayerNorm


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
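

# Illustrative usage sketch (hypothetical helper, not part of the original file): BertEmbeddings sums
# word, position and token-type embeddings into a (batch_size, seq_length, hidden_size) tensor;
# position_ids and token_type_ids default to 0..seq_length-1 and all zeros when omitted.
def _embeddings_usage_sketch(config):
    embeddings = BertEmbeddings(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 8))  # batch of 2 sequences of length 8
    return embeddings(input_ids=input_ids)  # -> shape (2, 8, config.hidden_size)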


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs
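

# Illustrative sketch (hypothetical helper, not part of the original file): the core of
# BertSelfAttention above is per-head scaled dot-product attention,
# softmax(Q K^T / sqrt(head_size) + mask) V, on tensors shaped (batch, heads, seq, head_size).
def _scaled_dot_product_sketch(query_layer, key_layer, value_layer, attention_mask=None):
    scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(query_layer.size(-1))
    if attention_mask is not None:
        scores = scores + attention_mask  # additive mask: 0.0 to keep a token, large negative to drop it
    probs = nn.Softmax(dim=-1)(scores)
    return torch.matmul(probs, value_layer)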


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
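

# Illustrative sketch (hypothetical, not part of the original file): prune_heads shrinks the
# query/key/value projections (and the matching input columns of the output projection), so the
# module keeps computing with the remaining heads only.
def _prune_heads_usage_sketch(config):
    attention = BertAttention(config)
    attention.prune_heads([0, 1])  # drop the first two heads of this layer
    return attention.self.num_attention_heads  # config.num_attention_heads - 2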


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        if self.is_decoder:
            self.crossattention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        self_attention_outputs = self.attention(
            hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + outputs
        return outputs


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
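

# Illustrative sketch (hypothetical helper, not part of the original file): the tuple returned by
# BertEncoder.forward is positional; its layout depends on the output flags, as unpacked here.
def _unpack_encoder_outputs_sketch(encoder_outputs, output_hidden_states, output_attentions):
    last_hidden_state = encoder_outputs[0]
    all_hidden_states = encoder_outputs[1] if output_hidden_states else None
    all_attentions = encoder_outputs[-1] if output_attentions else None
    return last_hidden_state, all_hidden_states, all_attentions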


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
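

# Illustrative sketch (hypothetical helper, not part of the original file): functionally, the pooler
# is just a dense projection plus tanh applied to the hidden state of the first ([CLS]) token.
def _pooler_equivalent_sketch(pooler, hidden_states):
    return torch.tanh(pooler.dense(hidden_states[:, 0]))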


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states
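

# Illustrative sketch (hypothetical helper, not part of the original file): as the comment in
# BertLMPredictionHead notes, the decoder weight is tied to the input word embeddings when the
# model is built by PreTrainedModel; only the output bias is a new parameter.
def _weights_are_tied_sketch(model):
    # For e.g. BertForMaskedLM, the input and output embedding matrices share the same storage.
    return model.get_output_embeddings().weight.data_ptr() == model.get_input_embeddings().weight.data_ptr()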


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


BERT_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
"""


@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class BertModel(BertPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well
    as a decoder, in which case a layer of cross-attention is added between
    the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the
    :obj:`is_decoder` argument of the configuration set to :obj:`True`; an
    :obj:`encoder_hidden_states` is expected as an input to the forward pass.

    .. _`Attention is all you need`:
        https://arxiv.org/abs/1706.03762

    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
    ):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during pre-training.

            This output is usually *not* a good summary of the semantic content of the input; you're often
            better off averaging or pooling the sequence of hidden-states for the whole input sequence.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import BertModel, BertTokenizer
        import torch

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained('bert-base-uncased')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)

        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
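

# Illustrative sketch (hypothetical helper, not part of the original file): conceptually,
# get_extended_attention_mask turns a (batch_size, seq_length) mask of 1s and 0s into an additive
# mask of shape (batch_size, 1, 1, seq_length) that is 0.0 for tokens to attend to and a large
# negative number for padding, so it can simply be added to the raw attention scores.
def _extended_attention_mask_sketch(attention_mask, dtype=torch.float32):
    extended_attention_mask = attention_mask[:, None, None, :].to(dtype=dtype)
    return (1.0 - extended_attention_mask) * -10000.0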


@add_start_docstrings(
    """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
    a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        next_sentence_label=None,
        output_attentions=None,
        **kwargs
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False
            continuation before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.


    Examples::

        from transformers import BertTokenizer, BertForPreTraining
        import torch

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForPreTraining.from_pretrained('bert-base-uncased')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)

        prediction_scores, seq_relationship_scores = outputs[:2]

        """
        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                DeprecationWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[
            2:
        ]  # add hidden states and attention if they are here

        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
            outputs = (total_loss,) + outputs

        return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)


# TODO: Split with a different BertWithLMHead to get rid of `lm_labels` here and in encoder_decoder.
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        lm_labels=None,
        output_attentions=None,
        **kwargs
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the left-to-right language modeling loss (next word prediction).
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Masked language modeling loss.
        ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_labels` is provided):
            Next token prediction loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

        Examples::

            from transformers import BertTokenizer, BertForMaskedLM
            import torch

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = BertForMaskedLM.from_pretrained('bert-base-uncased')

            input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
Sylvain Gugger's avatar
Sylvain Gugger committed
949
            outputs = model(input_ids, labels=input_ids)
Lysandre's avatar
Lysandre committed
950

Lysandre's avatar
Lysandre committed
951
952
953
            loss, prediction_scores = outputs[:2]

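            # An extra illustrative sketch (not part of the original example): a more typical masked-LM
            # setup replaces one position with the [MASK] token and supervises only that position,
            # marking every other label with the ignored index -100 described above.
            labels = torch.full_like(input_ids, -100)              # ignore every position by default
            masked_index = input_ids.shape[1] - 2                  # last real token before [SEP] (illustrative choice)
            labels[0, masked_index] = input_ids[0, masked_index]   # supervise only the masked position
            input_ids[0, masked_index] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
            masked_lm_loss, prediction_scores = model(input_ids, labels=labels)[:2]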
        """
        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                DeprecationWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here

        # Although this may seem awkward, BertForMaskedLM supports two scenarios:
        # 1. If a tensor that contains the indices of masked labels is provided,
        #    the cross-entropy is the MLM cross-entropy that measures the likelihood
        #    of predictions for masked words.
        # 2. If `lm_labels` is provided we are in a causal scenario where we
        #    try to predict the next token for each input in the decoder.
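        # Concretely (illustrative values): with lm_labels = [[t0, t1, t2, t3]], the shift performed
        # below compares the scores at positions 0..2 with the targets t1..t3, so each position is
        # trained to predict the token that follows it.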
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            outputs = (masked_lm_loss,) + outputs

        if lm_labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            prediction_scores = prediction_scores[:, :-1, :].contiguous()
            lm_labels = lm_labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
            outputs = (ltr_lm_loss,) + outputs

        return outputs  # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
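        """Assemble the model inputs used at each generation step.

        A minimal sketch of the behaviour (hypothetical ids and shapes, assuming a non-decoder
        config with ``pad_token_id`` set)::

            input_ids = torch.tensor([[101, 7592, 102]])              # shape (1, 3)
            inputs = model.prepare_inputs_for_generation(input_ids)
            inputs["input_ids"]       # shape (1, 4), ends with the PAD dummy token
            inputs["attention_mask"]  # shape (1, 4), last column is 0, so the dummy token is never attended to
        """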
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # if the model is used as a decoder in an encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # if the model does not use a causal mask then add a dummy token
        if self.config.is_decoder is False:
            assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
            )

            dummy_token = torch.full(
                (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
            )
            input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}


@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
)
class BertForNextSentencePrediction(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        next_sentence_label=None,
        output_attentions=None,
    ):
        r"""
        next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
            Next sequence prediction (classification) loss.
1058
        seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Lysandre's avatar
Lysandre committed
1059
1060
1061
1062
1063
1064
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1065
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Lysandre's avatar
Lysandre committed
1066
1067
1068
1069
1070
1071
1072
1073
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

Lysandre's avatar
Lysandre committed
1074
1075
1076
        from transformers import BertTokenizer, BertForNextSentencePrediction
        import torch

Lysandre's avatar
Lysandre committed
1077
1078
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
Lysandre's avatar
Lysandre committed
1079

1080
1081
1082
        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        encoding = tokenizer.encode_plus(prompt, next_sentence, return_tensors='pt')
Lysandre's avatar
Lysandre committed
1083

1084
1085
        loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
        assert logits[0, 0] < logits[0, 1] # next sentence was random
Lysandre's avatar
Lysandre committed
1086
        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
        )

        pooled_output = outputs[1]
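        # outputs[1] is the pooled representation of the first ([CLS]) token produced by BertPooler;
        # the NSP head scores it as a binary "sequence B follows A" / "sequence B is random" decision.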

        seq_relationship_score = self.cls(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
        if next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            outputs = (next_sentence_loss,) + outputs

        return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import BertTokenizer, BertForSequenceClassification
        import torch

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)

        loss, logits = outputs[:2]

        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING,
)
class BertForMultipleChoice(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification loss.
        classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            `num_choices` is the second dimension of the input tensors. (see `input_ids` above).

            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import BertTokenizer, BertForMultipleChoice
        import torch

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForMultipleChoice.from_pretrained('bert-base-uncased')

        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        choice0 = "It is eaten with a fork and a knife."
        choice1 = "It is eaten while held in the hand."
        labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

        encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
        outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1

        # the linear classifier still needs to be trained
        loss, logits = outputs[:2]
        """
        num_choices = input_ids.shape[1]

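        # Shape note: the inputs arrive as (batch_size, num_choices, seq_len) and are flattened below to
        # (batch_size * num_choices, seq_len) so a single BERT pass scores every choice; the per-choice
        # logits are reshaped back to (batch_size, num_choices) before the loss is computed.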
        input_ids = input_ids.view(-1, input_ids.size(-1))
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Classification loss.
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import BertTokenizer, BertForTokenClassification
        import torch

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForTokenClassification.from_pretrained('bert-base-uncased')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)

        loss, scores = outputs[:2]

        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
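            # (padded positions are mapped to loss_fct.ignore_index, which is -100 by default,
            #  so they contribute nothing to the cross-entropy)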
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING,
)
class BertForQuestionAnswering(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
thomwolf's avatar
thomwolf committed
1445
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
thomwolf's avatar
thomwolf committed
1446
1447
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
Lysandre's avatar
Lysandre committed
1448
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
thomwolf's avatar
thomwolf committed
1449
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
thomwolf's avatar
thomwolf committed
1450
1451
1452
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.

Lysandre's avatar
Lysandre committed
1453
    Returns:
Lysandre's avatar
Fixes  
Lysandre committed
1454
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
Lysandre's avatar
Lysandre committed
1455
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
thomwolf's avatar
thomwolf committed
1456
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
Lysandre's avatar
Lysandre committed
1457
        start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
thomwolf's avatar
thomwolf committed
1458
            Span-start scores (before SoftMax).
Lysandre's avatar
Lysandre committed
1459
        end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
thomwolf's avatar
thomwolf committed
1460
            Span-end scores (before SoftMax).
Lysandre's avatar
Lysandre committed
1461
1462
1463
1464
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

thomwolf's avatar
thomwolf committed
1465
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1466
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Lysandre's avatar
Lysandre committed
1467
1468
1469
1470
1471
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
thomwolf's avatar
thomwolf committed
1472
1473
1474

    Examples::

Lysandre's avatar
Lysandre committed
1475
1476
1477
        from transformers import BertTokenizer, BertForQuestionAnswering
        import torch

wangfei's avatar
wangfei committed
1478
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1479
        model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
Lysandre's avatar
Lysandre committed
1480

1481
        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
1482
1483
        encoding = tokenizer.encode_plus(question, text)
        input_ids, token_type_ids = encoding["input_ids"], encoding["token_type_ids"]
1484
        start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
Lysandre's avatar
Lysandre committed
1485

1486
        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
Lysandre's avatar
Lysandre committed
1487
1488
1489
        answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])

        assert answer == "a nice puppet"
1490

Lysandre's avatar
Lysandre committed
1491
        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)
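            # Clamping maps any out-of-range target onto `ignored_index` (== sequence length), which is
            # then passed as `ignore_index` below so those examples contribute nothing to the loss.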

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)