# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model. """

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import math
import os
import sys

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from .modeling_utils import PreTrainedModel, prune_linear_layer
from .configuration_bert import BertConfig
from .file_utils import add_start_docstrings

logger = logging.getLogger(__name__)

BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
    'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
    'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
    'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
    'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
    'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
    'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
    'bert-base-japanese': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
    'bert-base-japanese-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
    'bert-base-japanese-char': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
    'bert-base-japanese-char-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
    'bert-base-finnish-cased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
    'bert-base-finnish-uncased-v1': "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
}


def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
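        # Walk the TF variable name path (split on '/') and descend into the matching PyTorch sub-modules.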
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] == 'kernel' or l[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'output_bias' or l[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, l[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model


def gelu(x):
    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def gelu_new(x):
    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
        Also see https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    return x * torch.sigmoid(x)


def mish(x):
    return x * torch.tanh(nn.functional.softplus(x))


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}


BertLayerNorm = torch.nn.LayerNorm


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
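        # The output embedding is the element-wise sum of word, position, and token-type embeddings; LayerNorm and dropout follow.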

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
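        # Reshape from (batch, seq_len, all_head_size) to (batch, num_heads, seq_len, head_size)
        # so that attention scores are computed independently for each head.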
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
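        # Flatten the head mask and gather the indices of the attention dimensions that survive pruning.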
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        if self.is_decoder:
            self.crossattention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
        self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.is_decoder and encoder_hidden_states is not None:
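            # Decoder path: run a second, cross-attention block over the encoder hidden states.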
            cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + outputs
        return outputs


class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
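        # Accumulate per-layer hidden states and attention maps when the corresponding config flags are set.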
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)


class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super(BertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = BertConfig
    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


BERT_START_DOCSTRING = r"""    The BERT model was proposed in
    `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
    by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
    pre-trained using a combination of masked language modeling objective and next sentence prediction
    on a large corpus comprising the Toronto Book Corpus and Wikipedia.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
        https://arxiv.org/abs/1810.04805

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs:

                ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``

                ``token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1``

            (b) For single sequences:

                ``tokens:         [CLS] the dog is hairy . [SEP]``

                ``token_type_ids:   0   0   0   0  0     0   0``

            Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.

            Indices can be obtained using :class:`transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        **encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
            is configured as a decoder.
        **encoder_attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
"""

@add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertModel(BertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during Bert pretraining. This output is usually *not* a good summary
            of the semantic content of the input, you're often better with averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
                head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None):
        """ Forward pass on the Model.

        The model can behave as an encoder (with only self-attention) as well
        as a decoder, in which case a layer of cross-attention is added between
        the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
        Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

        To behave as a decoder the model needs to be initialized with the
        `is_decoder` argument of the configuration set to `True`; an
        `encoder_hidden_states` is expected as an input to the forward pass.

        .. _`Attention is all you need`:
            https://arxiv.org/abs/1706.03762

        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
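                # Build a lower-triangular causal mask: position i may only attend to positions j <= i.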
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                causal_mask = causal_mask.to(torch.long)  # not converting to long will cause errors with pytorch version < 1.3
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(input_shape, attention_mask.shape))

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)

            if encoder_attention_mask.dim() == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            elif encoder_attention_mask.dim() == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
            else:
                raise ValueError("Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(encoder_hidden_shape,
                                                                                                                               encoder_attention_mask.shape))

            encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
        encoder_outputs = self.encoder(embedding_output,
                                       attention_mask=extended_attention_mask,
                                       head_mask=head_mask,
                                       encoder_hidden_states=encoder_hidden_states,
                                       encoder_attention_mask=encoder_extended_attention_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
                       a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
                      BERT_START_DOCSTRING,
                      BERT_INPUTS_DOCSTRING)
class BertForPreTraining(BertPreTrainedModel):
    r"""
        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForPreTraining.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, seq_relationship_scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                masked_lm_labels=None, next_sentence_label=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

        if masked_lm_labels is not None and next_sentence_label is not None:
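            # Total pre-training loss is the sum of the MLM cross-entropy and the NSP cross-entropy.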
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
            outputs = (total_loss,) + outputs

        return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
                      BERT_START_DOCSTRING,
                      BERT_INPUTS_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
    r"""
        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the left-to-right language modeling loss (next word prediction).
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Masked language modeling loss.
        **ltr_lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Next token prediction loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
thomwolf's avatar
thomwolf committed
852
853
854
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
thomwolf's avatar
thomwolf committed
855
856
857

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, masked_lm_labels=input_ids)
        loss, prediction_scores = outputs[:2]
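
        # A minimal sketch of the left-to-right (causal) usage via ``lm_labels``; reusing
        # ``input_ids`` as the labels here is purely illustrative.
        outputs = model(input_ids, lm_labels=input_ids)
        ltr_lm_loss, prediction_scores = outputs[:2]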

    """
    def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            encoder_hidden_states=encoder_hidden_states,
                            encoder_attention_mask=encoder_attention_mask)

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here

        # Although this may seem awkward, BertForMaskedLM supports two scenarios:
        # 1. If a tensor that contains the indices of masked labels is provided,
        #    the cross-entropy is the MLM cross-entropy that measures the likelihood
        #    of predictions for masked words.
        # 2. If `lm_labels` is provided we are in a causal scenario where we
        #    try to predict the next token for each input in the decoder.
        if masked_lm_labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            outputs = (masked_lm_loss,) + outputs

        if lm_labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            prediction_scores = prediction_scores[:, :-1, :].contiguous()
            lm_labels = lm_labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
            outputs = (ltr_lm_loss,) + outputs

        return outputs  # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
                      BERT_START_DOCSTRING,
                      BERT_INPUTS_DOCSTRING)
class BertForNextSentencePrediction(BertPreTrainedModel):
    r"""
        **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Next sequence prediction (classification) loss.
        **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        seq_relationship_scores = outputs[0]
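
        # A minimal sketch of computing the NSP loss, assuming ``input_ids`` encodes a
        # sentence pair; the label value is illustrative (0 = sequence B follows sequence A).
        next_sentence_label = torch.tensor([0])  # Batch size 1
        outputs = model(input_ids, next_sentence_label=next_sentence_label)
        loss, seq_relationship_scores = outputs[:2]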

    """
    def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                next_sentence_label=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds)

        pooled_output = outputs[1]

        seq_relationship_score = self.cls(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
        if next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            outputs = (next_sentence_loss,) + outputs

        return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
                      the pooled output) e.g. for GLUE tasks. """,
                      BERT_START_DOCSTRING,
                      BERT_INPUTS_DOCSTRING)
class BertForSequenceClassification(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss);
            if ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]
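
        # A minimal regression sketch; assumes the ``num_labels=1`` keyword is accepted by
        # ``from_pretrained`` and forwarded to the configuration (hypothetical setup).
        regression_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=1)
        targets = torch.tensor([0.5]).unsqueeze(0)  # Batch size 1, one float target
        loss, logits = regression_model(input_ids, labels=targets)[:2]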

    """
    def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
                      the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
                      BERT_START_DOCSTRING,
                      BERT_INPUTS_DOCSTRING)
class BertForMultipleChoice(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above).
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
        input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
        labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, classification_scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForMultipleChoice, self).__init__(config)

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
        num_choices = input_ids.shape[1]

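        # Flatten the choice dimension so each choice is encoded as its own sequence:
        # (batch_size, num_choices, seq_length) -> (batch_size * num_choices, seq_length).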
        input_ids = input_ids.view(-1, input_ids.size(-1))
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
                      the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
                      BERT_START_DOCSTRING,
                      BERT_INPUTS_DOCSTRING)
class BertForTokenClassification(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForTokenClassification.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds)

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
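            # (i.e. positions where attention_mask == 1; padded positions do not
            # contribute to the loss)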
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
                      the hidden-states output to compute `span start logits` and `span end logits`). """,
                      BERT_START_DOCSTRING,
                      BERT_INPUTS_DOCSTRING)
class BertForQuestionAnswering(BertPreTrainedModel):
    r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        input_text = "[CLS] " + question + " [SEP] " + text + " [SEP]"
        input_ids = tokenizer.encode(input_text)
        token_type_ids = [0 if i <= input_ids.index(102) else 1 for i in range(len(input_ids))]
        start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
        print(' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))
        # a nice puppet
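
        # A minimal training sketch with illustrative span positions (not derived from the text above):
        start_positions = torch.tensor([1])
        end_positions = torch.tensor([3])
        outputs = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]),
                        start_positions=start_positions, end_positions=end_positions)
        loss = outputs[0]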


    """
    def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None,
                start_positions=None, end_positions=None):

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds)

        sequence_output = outputs[0]

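        # The QA head is a single linear layer producing ``config.num_labels`` scores per token
        # (expected to be 2 here), which are split below into span-start and span-end logits of
        # shape (batch_size, sequence_length).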
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)