modeling_bert.py 66 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
thomwolf's avatar
thomwolf committed
2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
thomwolf's avatar
thomwolf committed
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
thomwolf's avatar
thomwolf committed
16
"""PyTorch BERT model. """
thomwolf's avatar
thomwolf committed
17

thomwolf's avatar
thomwolf committed
18
from __future__ import absolute_import, division, print_function, unicode_literals
thomwolf's avatar
thomwolf committed
19
20
21

import json
import logging
thomwolf's avatar
thomwolf committed
22
23
24
25
import math
import os
import sys
from io import open
thomwolf's avatar
thomwolf committed
26
27
28

import torch
from torch import nn
29
from torch.nn import CrossEntropyLoss, MSELoss
thomwolf's avatar
thomwolf committed
30

thomwolf's avatar
thomwolf committed
31
32
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel,
                             prune_linear_layer, add_start_docstrings)
thomwolf's avatar
thomwolf committed
33
34
35

logger = logging.getLogger(__name__)

36
# Shortcut model name -> URL of the matching pretrained PyTorch weights file.
# Every URL follows the same pattern, so the table is derived from the shortcut names.
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    shortcut_name: "https://s3.amazonaws.com/models.huggingface.co/bert/{}-pytorch_model.bin".format(shortcut_name)
    for shortcut_name in (
        'bert-base-uncased',
        'bert-large-uncased',
        'bert-base-cased',
        'bert-large-cased',
        'bert-base-multilingual-uncased',
        'bert-base-multilingual-cased',
        'bert-base-chinese',
        'bert-base-german-cased',
        'bert-large-uncased-whole-word-masking',
        'bert-large-cased-whole-word-masking',
        'bert-large-uncased-whole-word-masking-finetuned-squad',
        'bert-large-cased-whole-word-masking-finetuned-squad',
        'bert-base-cased-finetuned-mrpc',
    )
}
51

52
# Shortcut model name -> URL of the matching JSON configuration file.
# Every URL follows the same pattern, so the table is derived from the shortcut names.
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    shortcut_name: "https://s3.amazonaws.com/models.huggingface.co/bert/{}-config.json".format(shortcut_name)
    for shortcut_name in (
        'bert-base-uncased',
        'bert-large-uncased',
        'bert-base-cased',
        'bert-large-cased',
        'bert-base-multilingual-uncased',
        'bert-base-multilingual-cased',
        'bert-base-chinese',
        'bert-base-german-cased',
        'bert-large-uncased-whole-word-masking',
        'bert-large-cased-whole-word-masking',
        'bert-large-uncased-whole-word-masking-finetuned-squad',
        'bert-large-cased-whole-word-masking-finetuned-squad',
        'bert-base-cased-finetuned-mrpc',
    )
}

thomwolf's avatar
thomwolf committed
68

69
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model.

    Every variable listed in the TensorFlow checkpoint is read into memory. Variables whose
    name contains an `adam_v`, `adam_m` or `global_step` component are ignored. For each
    remaining variable, the '/'-separated components of its name are followed as attribute
    names starting from `model`, and the array is written to the `.data` of the object reached.

    Arguments:
        model: Object whose nested attributes receive the checkpoint values.
        config: Not used by this function.
        tf_checkpoint_path: Path of the TensorFlow checkpoint; it is made absolute before use.

    Returns:
        The `model` argument, modified in place.

    Raises:
        ImportError: If `re`, `numpy` or `tensorflow` cannot be imported (logged, then re-raised).
        AssertionError: If the shape of a checkpoint array differs from the shape of its target.
            The exception args are `(pointer.shape, array.shape)`.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            # A component such as "layer_3" is split into the attribute name and an index.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                scope_names = re.split(r'_(\d+)', m_name)
            else:
                scope_names = [m_name]
            # Translate TensorFlow variable names to the PyTorch attribute names.
            if scope_names[0] == 'kernel' or scope_names[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif scope_names[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    # NOTE: this only skips the current name component; the walk goes on
                    # with the next component from the same `pointer`.
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)
        # Explicit check instead of `assert` so the validation also runs under `python -O`.
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model


thomwolf's avatar
thomwolf committed
136
137
138
139
def gelu(x):
    """Gaussian Error Linear Unit, evaluated with the error function.
        OpenAI GPT's gelu is a slightly different formula (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Reference: https://arxiv.org/abs/1606.08415
    """
    half_x = x * 0.5
    return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    """Swish activation: `x` multiplied element-wise by `sigmoid(x)`."""
    gate = torch.sigmoid(x)
    return x * gate


# Activation name -> callable; used to resolve `hidden_act` when it is given as a string.
ACT2FN = dict(
    gelu=gelu,
    relu=torch.nn.functional.relu,
    swish=swish,
)


152
class BertConfig(PretrainedConfig):
    r"""
        :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a
        `BertModel`.

        Arguments:
            vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel` (int), or
                the path (str) of a JSON file whose top-level keys are copied onto the instance as
                attributes. When a path is given, the other arguments listed here are not applied.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            layer_norm_eps: The epsilon used by LayerNorm.
    """
    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=30522,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12,
                 **kwargs):
        """Constructs BertConfig.

        The arguments are described in the class docstring. Extra keyword arguments are
        forwarded to the parent class constructor.

        Raises:
            ValueError: If `vocab_size_or_config_json_file` is neither a string nor an int.
        """
        super(BertConfig, self).__init__(**kwargs)
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            # A path was given: every key of the JSON file becomes an attribute.
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
        else:
            # The trailing space keeps the two concatenated literals from running together.
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

245

246

247
248
249
try:
    # Use apex's FusedLayerNorm under the BertLayerNorm name when apex is importable.
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")

    class BertLayerNorm(nn.Module):
        # Fallback implementation defined only when the apex import fails.
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(BertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            # Normalize over the last dimension, then apply the learned scale and shift.
            mean = x.mean(-1, keepdim=True)
            variance = (x - mean).pow(2).mean(-1, keepdim=True)
            normalized = (x - mean) / torch.sqrt(variance + self.variance_epsilon)
            return self.weight * normalized + self.bias
thomwolf's avatar
thomwolf committed
265
266
267
268
269
270

class BertEmbeddings(nn.Module):
    """Build the embeddings as the sum of word, position and token_type embeddings,
    followed by LayerNorm and dropout.
    """
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden_size)

        # The attribute keeps the TensorFlow variable name (not snake_case) so that
        # TensorFlow checkpoint files can be loaded.
        self.LayerNorm = BertLayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            # Default positions: 0 .. seq_length - 1, expanded to the shape of input_ids.
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            # Default token types: all zeros.
            token_type_ids = torch.zeros_like(input_ids)

        word_part = self.word_embeddings(input_ids)
        position_part = self.position_embeddings(position_ids)
        token_type_part = self.token_type_embeddings(token_type_ids)

        summed = word_part + position_part + token_type_part
        return self.dropout(self.LayerNorm(summed))


class BertSelfAttention(nn.Module):
    """Multi-head self-attention: query/key/value projections, scaled dot-product scores,
    softmax, dropout and the weighted sum of the values.
    """
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        head_count = config.num_attention_heads
        self.num_attention_heads = head_count
        self.attention_head_size = int(config.hidden_size / head_count)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # Split the last dimension into (heads, head_size), then swap dimensions 1 and 2.
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*split_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Raw attention scores: dot product of "query" and "key", scaled by sqrt(head_size).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # The attention mask is added to the scores (precomputed for all layers in BertModel forward() function).
        attention_scores = attention_scores + attention_mask

        # Turn the scores into probabilities over the last dimension.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Dropout here removes entire tokens to attend to, which might seem unusual,
        # but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Optional per-head mask.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Undo the head split: swap dimensions 1 and 2 back and merge the last two dimensions.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        if self.output_attentions:
            return (context_layer, attention_probs)
        return (context_layer,)
thomwolf's avatar
thomwolf committed
356
357
358
359
360
361


class BertSelfOutput(nn.Module):
    """Dense projection and dropout, then LayerNorm over the sum with `input_tensor`."""
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = BertLayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection: input_tensor is added before normalization.
        return self.LayerNorm(projected + input_tensor)


class BertAttention(nn.Module):
    """Combines `BertSelfAttention` with the `BertSelfOutput` block applied to its result."""
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def prune_heads(self, heads):
        """Remove the attention heads whose indices are listed in `heads`; no-op when empty."""
        if len(heads) == 0:
            return
        attention = self.self
        # One row per head; rows of heads to drop are zeroed, the rest mark kept units.
        keep = torch.ones(attention.num_attention_heads, attention.attention_head_size)
        for head in heads:
            keep[head] = 0
        keep = keep.view(-1).contiguous().eq(1)
        index = torch.arange(len(keep))[keep].long()
        # Prune linear layers
        for projection in ('query', 'key', 'value'):
            setattr(attention, projection, prune_linear_layer(getattr(attention, projection), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
        # Update hyper params
        attention.num_attention_heads = attention.num_attention_heads - len(heads)
        attention.all_head_size = attention.attention_head_size * attention.num_attention_heads

    def forward(self, input_tensor, attention_mask, head_mask=None):
        self_outputs = self.self(input_tensor, attention_mask, head_mask)
        attention_output = self.output(self_outputs[0], input_tensor)
        # Anything after the first self-attention output (the attentions, if any) is passed through.
        return (attention_output,) + self_outputs[1:]
thomwolf's avatar
thomwolf committed
400
401
402
403
404
405


class BertIntermediate(nn.Module):
    """Dense layer from `hidden_size` to `intermediate_size` followed by the activation."""
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string (or a Python 2 unicode string) is looked up in ACT2FN; anything else is used as given.
        given_by_name = isinstance(activation, str) or (
            sys.version_info[0] == 2 and isinstance(activation, unicode))
        self.intermediate_act_fn = ACT2FN[activation] if given_by_name else activation

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))


class BertOutput(nn.Module):
    """Dense layer from `intermediate_size` back to `hidden_size`, dropout, then LayerNorm
    over the sum with `input_tensor`.
    """
    def __init__(self, config):
        super(BertOutput, self).__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, hidden_size)
        self.LayerNorm = BertLayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection: input_tensor is added before normalization.
        return self.LayerNorm(projected + input_tensor)


class BertLayer(nn.Module):
    """One encoder layer: attention, then the intermediate and output blocks."""
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attention_output = attention_outputs[0]
        # The attention output is both the input of the intermediate block and the
        # second argument of the output block.
        layer_output = self.output(self.intermediate(attention_output), attention_output)
        # Anything after the first attention output (the attentions, if any) is passed through.
        return (layer_output,) + attention_outputs[1:]
thomwolf's avatar
thomwolf committed
445
446
447


class BertEncoder(nn.Module):
    """Applies `config.num_hidden_layers` `BertLayer` modules one after another."""
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, head_mask=None):
        """Run `hidden_states` through every layer in order.

        Arguments:
            hidden_states: Input passed to the first layer.
            attention_mask: Passed unchanged to every layer.
            head_mask: Optional indexable with one entry per layer; entry `i` is passed to
                layer `i`. When `None`, every layer receives `None`.

        Returns:
            A tuple `(hidden_states,)` holding the output of the last layer, followed by
            a tuple of the inputs of every layer plus the final output when
            `self.output_hidden_states` is set, followed by a tuple of the second output of
            every layer when `self.output_attentions` is set.
        """
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # The declared default head_mask=None must not be indexed.
            layer_head_mask = head_mask[i] if head_mask is not None else None
            layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # outputs, (hidden states), (attentions)
thomwolf's avatar
thomwolf committed
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497


class BertPooler(nn.Module):
    """Dense layer and tanh applied to the hidden state at index 0 of dimension 1."""
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "Pooling" simply means keeping the hidden state that corresponds
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        return self.activation(self.dense(first_token_tensor))


class BertPredictionHeadTransform(nn.Module):
    """Dense layer, activation and LayerNorm, all at `hidden_size`."""
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        activation = config.hidden_act
        # A string (or a Python 2 unicode string) is looked up in ACT2FN; anything else is used as given.
        given_by_name = isinstance(activation, str) or (
            sys.version_info[0] == 2 and isinstance(activation, unicode))
        self.transform_act_fn = ACT2FN[activation] if given_by_name else activation
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)


class BertLMPredictionHead(nn.Module):
    """Transform the hidden states, then project them to `vocab_size` scores plus a bias."""
    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token, so the linear layer itself has no bias.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        transformed = self.transform(hidden_states)
        return self.decoder(transformed) + self.bias


class BertOnlyMLMHead(nn.Module):
    """Wraps a single `BertLMPredictionHead` stored under the `predictions` attribute."""
    def __init__(self, config):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)


class BertOnlyNSPHead(nn.Module):
    """Linear layer mapping the pooled output to 2 scores."""
    def __init__(self, config):
        super(BertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        return self.seq_relationship(pooled_output)


class BertPreTrainingHeads(nn.Module):
    """The two pre-training heads side by side.

    ``predictions`` is a ``BertLMPredictionHead`` applied to the per-token sequence output;
    ``seq_relationship`` is a linear layer from ``config.hidden_size`` to 2 applied to the
    pooled output.
    """

    def __init__(self, config):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return ``(prediction_scores, seq_relationship_score)`` for the two inputs."""
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


562
class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = BertConfig
    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

    def __init__(self, *inputs, **kwargs):
        super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights of a single sub-module.

            Subclasses pass this method to ``self.apply`` so it is called once per
            sub-module:

            - ``nn.Linear`` / ``nn.Embedding`` weights are drawn from a normal distribution
              with mean 0 and std ``self.config.initializer_range``;
            - ``BertLayerNorm`` is reset to weight 1 and bias 0;
            - the bias of any ``nn.Linear`` that has one is zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Deliberately a separate ``if``: linear layers were handled by the first branch
        # above and additionally need their bias cleared.
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


thomwolf's avatar
thomwolf committed
588
589
590
591
592
# Shared introduction prepended to the documentation of the BERT model classes
# through the ``add_start_docstrings`` decorator.
BERT_START_DOCSTRING = r"""    The BERT model was proposed in
    `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
    by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
    pre-trained using a combination of masked language modeling objective and next sentence prediction
    on a large corpus comprising the Toronto Book Corpus and Wikipedia.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
        https://arxiv.org/abs/1810.04805

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model.
"""

# Shared description of the ``forward`` inputs, appended to the documentation of the
# BERT model classes through the ``add_start_docstrings`` decorator.
BERT_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs:

                ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
                
                ``token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1``

            (b) For single sequences:

                ``tokens:         [CLS] the dog is hairy . [SEP]``
                
                ``token_type_ids:   0   0   0   0  0     0   0``
    
            Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
        **attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask indices selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask indices selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""

@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertModel(BertPreTrainedModel):
    __doc__ = r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Output of the pooler module applied to ``last_hidden_state``. This is the representation
            the sentence-level heads (next sentence prediction, sequence classification) are computed from.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.

    Examples::

        >>> config = BertConfig.from_pretrained('bert-base-uncased')
        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> model = BertModel(config)
        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        >>> outputs = model(input_ids)
        >>> last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.apply(self.init_weights)

    def _resize_token_embeddings(self, new_num_tokens):
        """ Replace the word-embedding matrix with one resized to ``new_num_tokens`` rows
            (built by ``_get_resized_embeddings``) and return the new embedding module.
        """
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, head_mask=None):
        # Defaults: attend to every position and put every token in segment 0.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Both masks are cast to the dtype of the model parameters (fp16 compatibility);
        # look it up once.
        param_dtype = next(self.parameters()).dtype

        # We create a broadcastable 4D attention mask from the 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=param_dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                # Same per-head mask for every layer: [H] -> [1, 1, H, 1, 1] -> [L, 1, H, 1, 1]
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                # One mask per layer: [L, H] -> [L, 1, H, 1, 1]
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.to(dtype=param_dtype)  # switch to float if needed + fp16 compatibility
        else:
            # One ``None`` entry per layer so the encoder can index the mask by layer.
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
thomwolf's avatar
thomwolf committed
739
740


thomwolf's avatar
thomwolf committed
741
742
743
@add_start_docstrings("""Bert Model transformer BERT model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForPreTraining(BertPreTrainedModel):
    __doc__ = r"""
        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-1, 0, ..., config.vocab_size - 1]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size - 1]``
        **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.

    Examples::

        >>> config = BertConfig.from_pretrained('bert-base-uncased')
        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> 
        >>> model = BertForPreTraining(config)
        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        >>> outputs = model(input_ids)
        >>> prediction_scores, seq_relationship_scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        self.apply(self.init_weights)
        # Must run after init_weights so the shared matrix is not re-initialized afterwards.
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
                next_sentence_label=None, head_mask=None):
        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)

        # Per-token states feed the LM head, the pooled vector feeds the NSP head.
        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

        # The loss is only computed when BOTH label tensors are supplied.
        if masked_lm_labels is not None and next_sentence_label is not None:
            # Positions labelled -1 do not contribute to the loss.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
            outputs = (total_loss,) + outputs

        return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
thomwolf's avatar
thomwolf committed
816
817


thomwolf's avatar
thomwolf committed
818
819
@add_start_docstrings("""Bert Model transformer BERT model with a `language modeling` head on top. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
    __doc__ = r"""
        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-1, 0, ..., config.vocab_size - 1]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size - 1]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Masked language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.

    Examples::

        >>> config = BertConfig.from_pretrained('bert-base-uncased')
        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> 
        >>> model = BertForMaskedLM(config)
        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        >>> outputs = model(input_ids, masked_lm_labels=input_ids)
        >>> loss, prediction_scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.apply(self.init_weights)
        # Must run after init_weights so the shared matrix is not re-initialized afterwards.
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        # outputs[1] (the pooled output) is not used by this head and is dropped.
        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
        if masked_lm_labels is not None:
            # Positions labelled -1 do not contribute to the loss.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            outputs = (masked_lm_loss,) + outputs

        return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
thomwolf's avatar
thomwolf committed
881
882


thomwolf's avatar
thomwolf committed
883
884
@add_start_docstrings("""Bert Model transformer BERT model with a `next sentence prediction (classification)` head on top. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForNextSentencePrediction(BertPreTrainedModel):
    __doc__ = r"""
        **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Next sequence prediction (classification) loss.
        **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.

    Examples::

        >>> config = BertConfig.from_pretrained('bert-base-uncased')
        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> 
        >>> model = BertForNextSentencePrediction(config)
        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        >>> outputs = model(input_ids)
        >>> seq_relationship_scores = outputs[0]

    """
    def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        self.apply(self.init_weights)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)

        # Only the pooled vector (second element of the BertModel output) is classified.
        pooled_output = outputs[1]
        seq_relationship_score = self.cls(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
        if next_sentence_label is not None:
            # Rows labelled -1 do not contribute to the loss.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            outputs = (next_sentence_loss,) + outputs

        return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
thomwolf's avatar
thomwolf committed
938
939


thomwolf's avatar
thomwolf committed
940
941
942
@add_start_docstrings("""Bert Model transformer BERT model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForSequenceClassification(BertPreTrainedModel):
    __doc__ = r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.

    Examples::

        >>> config = BertConfig.from_pretrained('bert-base-uncased')
        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> 
        >>> model = BertForSequenceClassification(config)
        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        >>> labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        >>> outputs = model(input_ids, labels=labels)
        >>> loss, logits = outputs[:2]

    """
    def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Sized from the value cached above so the head and the loss in ``forward``
        # always agree on the number of labels.
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

        self.apply(self.init_weights)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)

        # Only the pooled vector (second element of the BertModel output) is classified.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)
1006
1007


thomwolf's avatar
thomwolf committed
1008
1009
1010
@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING)
class BertForMultipleChoice(BertPreTrainedModel):
    __doc__ = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            The second dimension of the input (`num_choices`) indicates the number of choices to score.
            To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs:

                ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``

                ``token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1``

            (b) For single sequences:

                ``tokens:         [CLS] the dog is hairy . [SEP]``

                ``token_type_ids:   0   0   0   0  0     0   0``

            Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs.
            The second dimension of the input (`num_choices`) indicates the number of choices to score.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
        **attention_mask**: (`optional`) ``torch.Tensor`` of shape ``(batch_size, num_choices, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            The second dimension of the input (`num_choices`) indicates the number of choices to score.
            Mask indices selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **head_mask**: (`optional`) ``torch.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask indices selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above).
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        >>> config = BertConfig.from_pretrained('bert-base-uncased')
        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>>
        >>> model = BertForMultipleChoice(config)
        >>> choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
        >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
        >>> labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
        >>> outputs = model(input_ids, labels=labels)
        >>> loss, classification_scores = outputs[:2]

    """
    def __init__(self, config):
        """Build the BERT encoder plus a dropout layer and a one-unit linear scorer."""
        super(BertForMultipleChoice, self).__init__(config)

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # One score per (example, choice) pair; the scores are regrouped per example in forward().
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.apply(self.init_weights)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
        """Score every choice of every example and optionally compute the classification loss.

        ``input_ids`` and, when given, ``position_ids``, ``token_type_ids`` and
        ``attention_mask`` are collapsed to two dimensions (keeping the last one)
        before being handed to the encoder. ``head_mask`` is passed through unchanged.

        Returns a tuple ``(loss), reshaped_logits, (hidden_states), (attentions)``; the
        loss is only present when ``labels`` is given.
        """
        def _flatten(tensor):
            # Merge all leading dimensions so each choice becomes its own row.
            return tensor.view(-1, tensor.size(-1)) if tensor is not None else None

        num_choices = input_ids.shape[1]

        outputs = self.bert(_flatten(input_ids), _flatten(position_ids), _flatten(token_type_ids),
                            _flatten(attention_mask), head_mask=head_mask)
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Regroup the per-choice scores so each row holds the scores of one example.
        reshaped_logits = logits.view(-1, num_choices)

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
1112
1113


thomwolf's avatar
thomwolf committed
1114
1115
1116
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForTokenClassification(BertPreTrainedModel):
    __doc__ = r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        >>> config = BertConfig.from_pretrained('bert-base-uncased')
        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>>
        >>> model = BertForTokenClassification(config)
        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        >>> labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        >>> outputs = model(input_ids, labels=labels)
        >>> loss, scores = outputs[:2]

    """
    def __init__(self, config):
        """Build the BERT encoder plus a dropout layer and a per-token linear classifier."""
        super(BertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self.init_weights)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
        """Classify every token and optionally compute the token classification loss.

        When both ``labels`` and ``attention_mask`` are given, only positions whose
        mask value equals 1 contribute to the loss; without a mask every position
        is used.

        Returns a tuple ``(loss), scores, (hidden_states), (attentions)``; the loss is
        only present when ``labels`` is given.
        """
        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                loss = loss_fct(flat_logits[active_loss], flat_labels[active_loss])
            else:
                loss = loss_fct(flat_logits, flat_labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)
thomwolf's avatar
thomwolf committed
1179
1180


thomwolf's avatar
thomwolf committed
1181
1182
1183
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForQuestionAnswering(BertPreTrainedModel):
    __doc__ = r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total span extraction loss is the average of a Cross-Entropy for the start and end positions.
        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        >>> config = BertConfig.from_pretrained('bert-base-uncased')
        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>>
        >>> model = BertForQuestionAnswering(config)
        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        >>> start_positions = torch.tensor([1])
        >>> end_positions = torch.tensor([3])
        >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        >>> loss, start_scores, end_scores = outputs[:3]

    """
    def __init__(self, config):
        """Build the BERT encoder plus the linear layer that produces span logits."""
        super(BertForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        # forward() unpacks this layer's output into exactly two tensors (start / end logits).
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self.init_weights)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, start_positions=None,
                end_positions=None, head_mask=None):
        """Compute span start/end logits and optionally the span extraction loss.

        The loss is the mean of the start and end cross-entropies and is only
        computed when both ``start_positions`` and ``end_positions`` are given.
        The position tensors passed in by the caller are left unmodified.

        Returns a tuple ``(loss), start_logits, end_logits, (hidden_states), (attentions)``.
        """
        outputs = self.bert(input_ids, position_ids, token_type_ids, attention_mask, head_mask=head_mask)
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_positions is not None and end_positions is not None:
            # On multi-GPU, the batch split can add a trailing dimension; drop it
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            # Out-of-place clamp: squeeze() returns a view, so an in-place clamp_ would
            # silently overwrite the label tensors owned by the caller.
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)