# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model. """


import logging
import math
import os

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from .configuration_bert import BertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, prune_linear_layer


logger = logging.getLogger(__name__)

BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
    "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
    "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
    "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
    "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
    "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
    "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
    "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
    "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
    "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
    "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
    "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
    "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
    "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
    "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
    "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
    "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
    "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
    "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
    "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
}


def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """ Load TF checkpoints in a PyTorch model.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required when using the pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
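
# Usage sketch (illustrative, not part of the upstream file): given a TensorFlow
# checkpoint from the original BERT release, the weights can be loaded into a
# freshly constructed PyTorch model. The paths below are placeholders:
#
#     config = BertConfig.from_json_file("/path/to/bert_config.json")
#     model = BertForPreTraining(config)
#     load_tf_weights_in_bert(model, config, "/path/to/bert_model.ckpt")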


def gelu(x):
    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def gelu_new(x):
    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
        Also see https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    return x * torch.sigmoid(x)


def mish(x):
    return x * torch.tanh(nn.functional.softplus(x))


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
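
# Illustrative note (not part of the upstream file): ACT2FN maps the string value of
# ``config.hidden_act`` to a callable. The erf-based ``gelu`` above and the tanh
# approximation ``gelu_new`` agree closely; a minimal check, assuming default dtypes:
#
#     x = torch.linspace(-3.0, 3.0, steps=7)
#     act = ACT2FN["gelu"]
#     assert torch.allclose(act(x), gelu_new(x), atol=1e-3)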


BertLayerNorm = torch.nn.LayerNorm


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
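
# Shape sketch (illustrative, not part of the upstream file): with a default BertConfig
# (hidden_size=768), BertEmbeddings maps integer ids of shape (batch_size, seq_length)
# to summed word + position + token-type embeddings of shape (batch_size, seq_length, 768):
#
#     config = BertConfig()
#     embeddings = BertEmbeddings(config)
#     input_ids = torch.zeros(2, 5, dtype=torch.long)
#     out = embeddings(input_ids=input_ids)  # -> torch.Size([2, 5, 768])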


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
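        # Reshape (batch_size, seq_length, all_head_size) into
        # (batch_size, num_attention_heads, seq_length, attention_head_size)
        # so that attention scores can be computed independently per head.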
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs
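
# Shape sketch (illustrative, not part of the upstream file): with the default BertConfig
# (hidden_size=768, num_attention_heads=12), transpose_for_scores splits the hidden size
# into 12 heads of size 64, scores are scaled by sqrt(64), and the per-head context is
# merged back to 768 features:
#
#     config = BertConfig()
#     self_attn = BertSelfAttention(config)
#     hidden = torch.randn(2, 7, config.hidden_size)
#     (context,) = self_attn(hidden)  # context: torch.Size([2, 7, 768])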


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        self_outputs = self.self(
            hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
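
# Pruning sketch (illustrative, not part of the upstream file): prune_heads drops the rows
# of the query/key/value projections (and the matching columns of the output projection)
# that belong to the given head indices, then updates the head bookkeeping:
#
#     config = BertConfig()                    # 12 heads of size 64
#     attention = BertAttention(config)
#     attention.prune_heads([0, 2])
#     attention.self.num_attention_heads       # -> 10
#     attention.self.query.out_features        # -> 640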


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        if self.is_decoder:
            self.crossattention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + outputs
        return outputs


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
            )
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
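
# Output sketch (illustrative, not part of the upstream file): with output_hidden_states=True
# the encoder also returns a tuple holding the embedding output plus one tensor per layer
# (13 entries for a 12-layer config). head_mask is passed explicitly because forward indexes
# it per layer:
#
#     config = BertConfig(output_hidden_states=True)
#     encoder = BertEncoder(config)
#     hidden = torch.randn(1, 4, config.hidden_size)
#     outputs = encoder(hidden, head_mask=[None] * config.num_hidden_layers)
#     len(outputs[1])  # -> 13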


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states
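
# Weight-tying note (illustrative, not part of the upstream file): in the full models the
# decoder weight is tied to the input word embeddings by PreTrainedModel.init_weights, while
# the separate bias above stays attached to the decoder so resize_token_embeddings resizes both:
#
#     model = BertForMaskedLM(BertConfig())
#     model.cls.predictions.decoder.weight is model.bert.embeddings.word_embeddings.weight
#     # -> True (unless config.torchscript is set, in which case the weight is cloned)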


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


BERT_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
        https://arxiv.org/abs/1810.04805

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
"""


@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class BertModel(BertPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well
    as a decoder, in which case a layer of cross-attention is added between
    the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the
    :obj:`is_decoder` argument of the configuration set to :obj:`True`; an
    :obj:`encoder_hidden_states` is expected as an input to the forward pass.

    .. _`Attention is all you need`:
        https://arxiv.org/abs/1706.03762

    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        r"""
    Return:
        :obj:`Tuple` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during pre-training.

            This output is usually *not* a good summary
            of the semantic content of the input, you're often better with averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                causal_mask = causal_mask.to(
                    torch.long
                )  # not converting to long will cause errors with pytorch version < 1.3
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
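                # e.g. for seq_length = 3 the causal part of the mask is
                # [[1, 0, 0],
                #  [1, 1, 0],
                #  [1, 1, 1]], i.e. each position may only attend to itself and earlier positions.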
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
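        # e.g. a padding mask row [1, 1, 0] becomes [0.0, 0.0, -10000.0], which leaves real
        # tokens untouched and pushes padded positions to ~zero probability after the softmax.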

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)

            if encoder_attention_mask.dim() == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            elif encoder_attention_mask.dim() == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
            else:
                raise ValueError(
                    "Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(
                        encoder_hidden_shape, encoder_attention_mask.shape
                    )
                )

            encoder_extended_attention_mask = encoder_extended_attention_mask.to(
                dtype=next(self.parameters()).dtype
            )  # fp16 compatibility
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and 
    a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
        next_sentence_label=None,
    ):
        r"""
        masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False
            continuation before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.


    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForPreTraining.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, seq_relationship_scores = outputs[:2]

        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[
            2:
        ]  # add hidden states and attention if they are here

        if masked_lm_labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
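            # CrossEntropyLoss ignores label index -100 by default, so tokens labelled -100
            # in masked_lm_labels (the non-masked positions) do not contribute to the MLM loss.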
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
            outputs = (total_loss,) + outputs

        return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING
)
class BertForMaskedLM(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        lm_labels=None,
    ):
        r"""
        masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size - 1]`` (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size - 1]``.
        lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the left-to-right language modeling loss (next word prediction).
            Indices should be in ``[-100, 0, ..., config.vocab_size - 1]`` (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size - 1]``.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        masked_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`masked_lm_labels` is provided):
            Masked language modeling loss.
        ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_labels` is provided):
            Next token prediction loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

        Examples::

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = BertForMaskedLM.from_pretrained('bert-base-uncased')
            input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            outputs = model(input_ids, masked_lm_labels=input_ids)
            loss, prediction_scores = outputs[:2]
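
            # Hedged sketch (not part of the original example) of labels that use the -100 ignore
            # index described above: every position is ignored except the one that is actually
            # masked, so the loss is only computed on the masked token.
            masked_input_ids = input_ids.clone()
            masked_input_ids[0, 4] = tokenizer.mask_token_id          # hypothetically mask the 5th token
            masked_lm_labels = torch.full_like(input_ids, -100)       # -100 everywhere = ignored
            masked_lm_labels[0, 4] = input_ids[0, 4]                  # supervise only the masked position
            loss = model(masked_input_ids, masked_lm_labels=masked_lm_labels)[0]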

        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here

        # Although this may seem awkward, BertForMaskedLM supports two scenarios:
        # 1. If a tensor that contains the indices of masked labels is provided,
        #    the cross-entropy is the MLM cross-entropy that measures the likelihood
        #    of predictions for masked words.
        # 2. If `lm_labels` is provided we are in a causal scenario where we
        #    try to predict the next token for each input in the decoder.
        if masked_lm_labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            outputs = (masked_lm_loss,) + outputs

        if lm_labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            prediction_scores = prediction_scores[:, :-1, :].contiguous()
            lm_labels = lm_labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
            outputs = (ltr_lm_loss,) + outputs

        return outputs  # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top. """,
    BERT_START_DOCSTRING,
)
class BertForNextSentencePrediction(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        next_sentence_label=None,
    ):
        r"""
        next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring).
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
            Next sequence prediction (classification) loss.
        seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        seq_relationship_scores = outputs[0]
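
        # Hedged sketch (assumed usage, not from the original example): encoding a real sentence
        # pair and passing `next_sentence_label` (0 = B follows A) prepends the NSP loss.
        pair_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", "He likes to play fetch",
                                                 add_special_tokens=True)).unsqueeze(0)
        loss, seq_relationship_scores = model(pair_ids, next_sentence_label=torch.tensor([0]))[:2]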

        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        pooled_output = outputs[1]

        seq_relationship_score = self.cls(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
        if next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            outputs = (next_sentence_loss,) + outputs

        return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of 
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss);
            if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]
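
        # Hedged regression sketch (assumed usage): configured with `num_labels == 1`, the same
        # call computes a mean-squared-error loss against float targets instead of class indices.
        regression_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=1)
        regression_labels = torch.tensor([[0.7]])       # hypothetical float target, shape (batch_size, 1)
        regression_loss = regression_model(input_ids, labels=regression_labels)[0]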

        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING,
)
class BertForMultipleChoice(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification loss.
        classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            `num_choices` is the second dimension of the input tensors. (see `input_ids` above).

            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
        input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
        labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, classification_scores = outputs[:2]
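
        # Hedged follow-up (not in the original example): `classification_scores` has shape
        # (batch_size, num_choices), so the predicted choice can be read off with an argmax.
        predicted_choice = classification_scores.argmax(dim=-1)   # index of the highest-scoring choice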

        """
        num_choices = input_ids.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1))
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Classification loss.
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForTokenClassification.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, scores = outputs[:2]
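
        # Hedged sketch (assumed usage): passing `attention_mask` restricts the loss to real tokens,
        # since padded positions are filtered out of the cross-entropy in the forward pass.
        attention_mask = torch.ones_like(input_ids)     # all ones here; zeros would mark padding
        loss, scores = model(input_ids, attention_mask=attention_mask, labels=labels)[:2]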

        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


@add_start_docstrings(
    """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear 
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING,
)
class BertForQuestionAnswering(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`start_positions` and :obj:`end_positions` are provided):
            Total span extraction loss is the average of the Cross-Entropy losses for the start and end positions.
        start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-start scores (before SoftMax).
        end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-end scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        input_ids = tokenizer.encode(question, text)
        token_type_ids = [0 if i <= input_ids.index(102) else 1 for i in range(len(input_ids))]
        start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
        print(' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))
        # a nice puppet
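
        # Hedged training-mode sketch (assumed usage): with gold answer-span token indices the model
        # prepends the span-extraction loss (the averaged start/end cross-entropy) to its outputs.
        start_positions = torch.tensor([12])            # hypothetical index of the first answer token
        end_positions = torch.tensor([14])              # hypothetical index of the last answer token
        loss = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]),
                     start_positions=start_positions, end_positions=end_positions)[0]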

        """

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)