"vscode:/vscode.git/clone" did not exist on "aa50fd196f16c693daa0d15f53272849819bc75b"
modeling_bert.py 57.8 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
thomwolf's avatar
thomwolf committed
2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
thomwolf's avatar
thomwolf committed
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
thomwolf's avatar
thomwolf committed
16
"""PyTorch BERT model. """
thomwolf's avatar
thomwolf committed
17

thomwolf's avatar
thomwolf committed
18
from __future__ import absolute_import, division, print_function, unicode_literals
thomwolf's avatar
thomwolf committed
19
20
21

import json
import logging
thomwolf's avatar
thomwolf committed
22
23
24
25
import math
import os
import sys
from io import open
thomwolf's avatar
thomwolf committed
26
27
28

import torch
from torch import nn
29
from torch.nn import CrossEntropyLoss, MSELoss
thomwolf's avatar
thomwolf committed
30

31
32
33
from .modeling_utils import PreTrainedModel, prune_linear_layer
from .configuration_bert import BertConfig
from .file_utils import add_start_docstrings
thomwolf's avatar
thomwolf committed
34
35
36

logger = logging.getLogger(__name__)

# Every released BERT checkpoint lives in the same S3 folder and follows the
# "<shortcut-name>-pytorch_model.bin" naming scheme, so the shortcut-name -> URL
# map is derived from the list of shortcut names (insertion order preserved).
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    shortcut_name: "https://s3.amazonaws.com/models.huggingface.co/bert/" + shortcut_name + "-pytorch_model.bin"
    for shortcut_name in (
        'bert-base-uncased',
        'bert-large-uncased',
        'bert-base-cased',
        'bert-large-cased',
        'bert-base-multilingual-uncased',
        'bert-base-multilingual-cased',
        'bert-base-chinese',
        'bert-base-german-cased',
        'bert-large-uncased-whole-word-masking',
        'bert-large-cased-whole-word-masking',
        'bert-large-uncased-whole-word-masking-finetuned-squad',
        'bert-large-cased-whole-word-masking-finetuned-squad',
        'bert-base-cased-finetuned-mrpc',
    )
}
52

53
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model.

    Every variable of the TensorFlow checkpoint is read. Its ``/``-separated
    scope is walked down the attribute tree of ``model``, and the array is
    copied into the tensor found at the end of that walk. Optimizer slots
    (``adam_v``, ``adam_m``) and ``global_step`` are skipped.

    Args:
        model: PyTorch module whose parameters are overwritten in place.
        config: not read by this function (presumably accepted so the
            signature matches the ``load_tf_weights`` hook of
            ``BertPreTrainedModel``).
        tf_checkpoint_path: path of the TensorFlow checkpoint to convert.

    Returns:
        ``model`` itself, after its parameters have been overwritten.

    Raises:
        ImportError: if numpy or TensorFlow cannot be imported.
        AssertionError: if a checkpoint array and its target tensor have
            different shapes; ``args`` is ``(pytorch_shape, tf_shape)``.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using a pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            # "layer_3" -> ["layer", "3", ""]: an attribute name followed by a list index.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                scope_names = re.split(r'_(\d+)', m_name)
            else:
                scope_names = [m_name]
            # TF variable names that differ from the PyTorch attribute names.
            if scope_names[0] in ('kernel', 'gamma', 'output_weights'):
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] in ('output_bias', 'beta'):
                pointer = getattr(pointer, 'bias')
            elif scope_names[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    # NOTE(review): this `continue` only moves on to the next scope
                    # component. `pointer` stays on the parent, so a scope level that
                    # `model` lacks (e.g. a leading prefix) is effectively ignored, and
                    # the walk resumes from there. The log says "Skipping", but the
                    # variable itself is NOT skipped. Confirm this is intended before
                    # changing it: loading into a model without that prefix may rely on it.
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                pointer = pointer[int(scope_names[1])]
        # `m_name` is the last scope component of this variable.
        if m_name.endswith('_embeddings'):
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            # TF dense kernels are stored transposed relative to nn.Linear.weight.
            array = np.transpose(array)
        # Explicit check instead of `assert`: `python -O` strips assert statements,
        # which would let a mismatched array through to the `.data` assignment below.
        # The exception type and its args (pytorch_shape, tf_shape) are unchanged.
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model


thomwolf's avatar
thomwolf committed
120
121
122
123
def gelu(x):
    """Gaussian Error Linear Unit, exact (erf-based) form: ``x * Phi(x)``.

    ``Phi`` is the standard normal CDF, written here as ``0.5 * (1 + erf(x / sqrt(2)))``.
    OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    erf_term = torch.erf(x / math.sqrt(2.0))
    return x * 0.5 * (1.0 + erf_term)


def swish(x):
    """Swish activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


# Activation callables looked up by name when ``config.hidden_act`` is a string.
ACT2FN = {
    "gelu": gelu,
    "relu": torch.nn.functional.relu,
    "swish": swish,
}


136
137
# Use apex's FusedLayerNorm when it can be imported; otherwise fall back to
# torch.nn.LayerNorm. Both are constructed in this file as
# BertLayerNorm(hidden_size, eps=...).
try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except (ImportError, AttributeError):
    # AttributeError is caught as well, presumably so that an apex install whose
    # module lacks the expected attribute also falls back instead of failing import.
    logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
    BertLayerNorm = torch.nn.LayerNorm
thomwolf's avatar
thomwolf committed
141
142
143
144
145
146

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    The three lookups are summed, then passed through LayerNorm and dropout.
    """
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        dim = config.hidden_size
        # Vocabulary index 0 is declared as the padding index of the word embedding.
        self.word_embeddings = nn.Embedding(config.vocab_size, dim, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, dim)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, dim)

        # "LayerNorm" keeps its CamelCase name so it matches the TensorFlow variable
        # name and TensorFlow checkpoints can be loaded.
        self.LayerNorm = BertLayerNorm(dim, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """Embed ``input_ids``.

        Args:
            input_ids: token indices; dimension 1 is used as the sequence length.
            token_type_ids: segment indices; defaults to all zeros shaped like ``input_ids``.
            position_ids: position indices; defaults to ``0 .. seq_length - 1``
                expanded to the shape of ``input_ids``.

        Returns:
            ``dropout(LayerNorm(word + position + token_type embeddings))``.
        """
        seq_length = input_ids.size(1)
        if position_ids is None:
            positions = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = positions.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        summed = (self.word_embeddings(input_ids)
                  + self.position_embeddings(position_ids)
                  + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(summed))


class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over ``hidden_states``.

    Raises:
        ValueError: if ``config.hidden_size`` is not a multiple of
            ``config.num_attention_heads``.
    """
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        # Exact integer division (divisibility was checked above); avoids a float round-trip.
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        # Kept separately from hidden_size because head pruning may later shrink it.
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis forward.

        ``(batch, seq, all_head_size)`` -> ``(batch, num_heads, seq, head_size)``.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        """Attend over ``hidden_states``.

        Args:
            hidden_states: 3-D input (it is split into heads by ``transpose_for_scores``).
            attention_mask: additive mask broadcastable to the
                ``(batch, num_heads, seq, seq)`` score tensor, or ``None`` for no masking.
            head_mask: optional multiplicative mask applied to the attention probabilities.

        Returns:
            ``(context_layer,)``, or ``(context_layer, attention_probs)`` when
            ``config.output_attentions`` is set.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities (no per-call nn.Softmax module).
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, heads * head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs
thomwolf's avatar
thomwolf committed
232
233
234
235
236
237


class BertSelfOutput(nn.Module):
    """Projection after self-attention, followed by dropout and a residual LayerNorm."""
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = BertLayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)


class BertAttention(nn.Module):
    """Self-attention sub-layer (BertSelfAttention + BertSelfOutput) with head pruning."""
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        # Indices (in the original, un-pruned numbering) of heads already removed.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from this layer.

        Args:
            heads: iterable of head indices in the layer's original (un-pruned)
                numbering. Heads that were already pruned are ignored.

        query/key/value are pruned with ``prune_linear_layer``'s default dimension,
        and ``output.dense`` with ``dim=1``. When no new head remains to be removed,
        the layers are left untouched.
        """
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        if len(heads) == 0:
            # Nothing new to prune. prune_linear_layer's results replace the existing
            # layers below (presumably as new modules), so skip it for a no-op.
            return
        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 for h in self.pruned_heads if h < head)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, input_tensor, attention_mask, head_mask=None):
        """Return ``(attention_output,)`` plus the attention probabilities if BertSelfAttention returned them."""
        self_outputs = self.self(input_tensor, attention_mask, head_mask)
        attention_output = self.output(self_outputs[0], input_tensor)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
thomwolf's avatar
thomwolf committed
283
284
285
286
287
288


class BertIntermediate(nn.Module):
    """Feed-forward expansion: ``hidden_size -> intermediate_size``, then the activation."""
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        # A string (or a Python 2 unicode) names an ACT2FN entry; anything else is used as-is.
        is_name = isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode))
        self.intermediate_act_fn = ACT2FN[act] if is_name else act

    def forward(self, hidden_states):
        """Return ``act(dense(hidden_states))``."""
        return self.intermediate_act_fn(self.dense(hidden_states))


class BertOutput(nn.Module):
    """Feed-forward contraction (``intermediate_size -> hidden_size``) with dropout and a residual LayerNorm."""
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        residual = input_tensor
        out = self.dense(hidden_states)
        out = self.dropout(out)
        return self.LayerNorm(out + residual)


class BertLayer(nn.Module):
    """One encoder block: the self-attention sub-layer, then the feed-forward sub-layer."""
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        """Return ``(layer_output,)`` followed by any extra outputs of the attention sub-layer."""
        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attended = attention_outputs[0]
        # The feed-forward output uses the attention output as its residual input.
        layer_output = self.output(self.intermediate(attended), attended)
        # Extra entries (the attention probabilities, when requested) are passed through.
        extras = attention_outputs[1:]
        return (layer_output,) + extras
thomwolf's avatar
thomwolf committed
328
329
330


class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` BertLayer blocks."""
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, head_mask=None):
        """Run ``hidden_states`` through every layer.

        Args:
            hidden_states: input to the first layer.
            attention_mask: passed unchanged to every layer.
            head_mask: per-layer head masks indexed by layer position
                (``head_mask[i]`` goes to layer ``i``); ``None`` disables head
                masking for every layer.

        Returns:
            ``(last_hidden_state,)``. Followed by the tuple of hidden states
            (embedding input plus every layer output) when ``output_hidden_states``
            is set, then by the tuple of per-layer attentions when
            ``output_attentions`` is set.
        """
        if head_mask is None:
            # Indexing the `None` default below would raise TypeError; a None per
            # layer means "no head masking" all the way down to BertSelfAttention.
            head_mask = [None] * len(self.layer)
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
thomwolf's avatar
thomwolf committed
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380


class BertPooler(nn.Module):
    """Pool a sequence into one vector: ``tanh(dense(hidden state of the first position))``."""
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Only position 0 (the [CLS] token in BERT-formatted inputs) is used;
        # every other position is ignored.
        return self.activation(self.dense(hidden_states[:, 0]))


class BertPredictionHeadTransform(nn.Module):
    """``dense -> activation -> LayerNorm`` applied to hidden states before the LM output projection."""
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        # A string (or a Python 2 unicode) names an ACT2FN entry; anything else is used as-is.
        if isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode)):
            act = ACT2FN[act]
        self.transform_act_fn = act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return ``LayerNorm(act(dense(hidden_states)))``."""
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))


class BertLMPredictionHead(nn.Module):
    """LM output head: transform the hidden states, then score every vocabulary entry."""
    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token. (Any weight tying happens outside
        # this class; here the projection is just created bias-free, and the
        # bias is a separate parameter.)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        """Return ``decoder(transform(hidden_states)) + bias``."""
        logits = self.decoder(self.transform(hidden_states))
        return logits + self.bias


class BertOnlyMLMHead(nn.Module):
    """Head holding only the LM predictions (no next-sentence classifier)."""
    def __init__(self, config):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the vocabulary scores computed by ``self.predictions``."""
        return self.predictions(sequence_output)


class BertOnlyNSPHead(nn.Module):
    """Next-sentence head: a 2-way linear layer applied to the pooled output."""
    def __init__(self, config):
        super(BertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Return the two-way scores for ``pooled_output``."""
        return self.seq_relationship(pooled_output)


class BertPreTrainingHeads(nn.Module):
    """Both pre-training heads: LM vocabulary scores and the 2-way next-sentence score."""
    def __init__(self, config):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return ``(prediction_scores, seq_relationship_score)``."""
        lm_scores = self.predictions(sequence_output)
        nsp_scores = self.seq_relationship(pooled_output)
        return lm_scores, nsp_scores


445
class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.

    The class attributes plug BERT's configuration class, checkpoint URL map,
    TensorFlow-checkpoint loader and base-model attribute prefix into
    ``PreTrainedModel``.
    """
    config_class = BertConfig
    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

    def _init_weights(self, module):
        """ Initialize the weights of one module.

        Linear/Embedding weights are drawn from N(0, config.initializer_range),
        and Linear biases (when present) are zeroed. BertLayerNorm gets bias 0
        and weight 1. Other modules are left untouched.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if is_linear and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


thomwolf's avatar
thomwolf committed
467
468
469
470
471
# Shared model-description prefix passed to add_start_docstrings for the BERT classes.
BERT_START_DOCSTRING = r"""    The BERT model was proposed in
    `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
    by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
    pre-trained using a combination of masked language modeling objective and next sentence prediction
    on a large corpus comprising the Toronto Book Corpus and Wikipedia.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
        https://arxiv.org/abs/1810.04805

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~pytorch_transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

# Shared description of the forward() inputs, passed to add_start_docstrings for the BERT classes.
BERT_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs:

                ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``

                ``token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1``

            (b) For single sequences:

                ``tokens:         [CLS] the dog is hairy . [SEP]``

                ``token_type_ids:   0   0   0   0  0     0   0``

            Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.

            Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""

@add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
                      BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertModel(BertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during Bert pretraining. This output is usually *not* a good summary
            of the semantic content of the input, you're often better with averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)

        # Sub-module creation order is kept as-is: it fixes parameter registration order
        # (and therefore state_dict layout and which tensor next(self.parameters()) yields).
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        """ Replace the word-embedding module with a copy resized to ``new_num_tokens`` entries
            (resizing itself is done by the base-class helper ``_get_resized_embeddings``)
            and return the new module.
        """
        self.embeddings.word_embeddings = self._get_resized_embeddings(self.embeddings.word_embeddings,
                                                                       new_num_tokens)
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer_num, layer_heads in heads_to_prune.items():
            target_attention = self.encoder.layer[layer_num].attention
            target_attention.prune_heads(layer_heads)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        """ Embed ``input_ids``, run the encoder stack and pool the result.

            Returns ``(sequence_output, pooled_output)`` followed by whatever extra outputs the
            encoder produced (hidden states and/or attentions, depending on the config).
        """
        # Defaults: attend to every position, and give every token segment id 0.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Both masks are cast to the dtype of the model weights (fp16 compatibility).
        param_dtype = next(self.parameters()).dtype

        # [batch_size, to_seq_length] -> [batch_size, 1, 1, to_seq_length], which broadcasts over
        # [batch_size, num_heads, from_seq_length, to_seq_length]. No causal triangle is needed here,
        # unlike OpenAI GPT. After the affine map, attended positions (1.0) become 0.0 and masked
        # positions (0.0) become -10000.0. Adding that to the raw scores before the softmax
        # effectively removes the masked positions.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=param_dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # head_mask: 1.0 keeps a head. Accepted as [num_heads] (shared by all layers) or
        # [num_hidden_layers, num_heads]. Both are reshaped to [num_hidden_layers, 1, num_heads, 1, 1]
        # so that each layer's slice broadcasts against attention probabilities of shape
        # [batch, num_heads, N, N]. Without a mask, the encoder gets one None per layer.
        num_layers = self.config.num_hidden_layers
        if head_mask is None:
            head_mask = [None] * num_layers
        else:
            if head_mask.dim() == 1:
                head_mask = head_mask[None, None, :, None, None].expand(num_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask[:, None, :, None, None]  # a separate mask for every layer
            head_mask = head_mask.to(dtype=param_dtype)  # float + fp16 compatibility

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        # sequence_output, pooled_output, (hidden_states), (attentions)
        return (sequence_output, pooled_output) + encoder_outputs[1:]


@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForPreTraining(BertPreTrainedModel):
    r"""
        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForPreTraining.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, seq_relationship_scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()
        # Tying must happen after init_weights so the shared weights are not re-initialised.
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None, next_sentence_label=None):
        """ Run BERT followed by the MLM and NSP pre-training heads.

            Returns ``(prediction_scores, seq_relationship_score)`` followed by any hidden states /
            attentions from ``BertModel``. When BOTH ``masked_lm_labels`` and ``next_sentence_label``
            are given, the summed loss is prepended. If only one of them is given, no loss is computed.
        """
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        # The MLM head reads every token's hidden state; the NSP head reads only the pooled output.
        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

        if masked_lm_labels is not None and next_sentence_label is not None:
            # ignore_index=-1: label positions set to -1 contribute nothing to either loss.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
            outputs = (total_loss,) + outputs

        return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
    r"""
        **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Masked language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, masked_lm_labels=input_ids)
        loss, prediction_scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)

        # Encoder first, then the MLM-only head (registration order kept).
        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Share (or, for TorchScript export, clone) the input word embeddings into the
            output projection of the language-modeling head.
        """
        output_projection = self.cls.predictions.decoder
        input_embeddings = self.bert.embeddings.word_embeddings
        self._tie_or_clone_weights(output_projection, input_embeddings)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                masked_lm_labels=None):
        """ Score every vocabulary token at every position, optionally with the MLM loss.

            Returns ``(prediction_scores,)`` followed by any hidden states / attentions from
            ``BertModel``. When ``masked_lm_labels`` is given, the loss is prepended.
        """
        bert_outputs = self.bert(input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=position_ids,
                                 head_mask=head_mask)

        prediction_scores = self.cls(bert_outputs[0])
        # Skip the pooled output (index 1); keep (hidden_states), (attentions) if present.
        extras = bert_outputs[2:]

        if masked_lm_labels is None:
            return (prediction_scores,) + extras  # prediction_scores, (hidden_states), (attentions)

        # Positions labelled -1 are excluded from the loss.
        mlm_loss = CrossEntropyLoss(ignore_index=-1)(prediction_scores.view(-1, self.config.vocab_size),
                                                     masked_lm_labels.view(-1))
        return (mlm_loss, prediction_scores) + extras  # masked_lm_loss, prediction_scores, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForNextSentencePrediction(BertPreTrainedModel):
    r"""
        **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
            Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Next sequence prediction (classification) loss.
        **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        seq_relationship_scores = outputs[0]

    """
    def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)

        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                next_sentence_label=None):
        """ Classify whether sentence B follows sentence A.

            Returns ``(seq_relationship_score,)`` followed by any hidden states / attentions from
            ``BertModel``. When ``next_sentence_label`` is given, the loss is prepended.
        """
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        # The NSP head only sees the pooled output, so the scores have one row per sequence.
        pooled_output = outputs[1]
        seq_relationship_score = self.cls(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

        if next_sentence_label is not None:
            # ignore_index=-1: examples labelled -1 do not contribute to the loss.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            outputs = (next_sentence_loss,) + outputs

        return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForSequenceClassification(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
            For regression the labels are real-valued targets (a ``torch.FloatTensor`` is expected);
            integer labels are cast to the dtype of the logits before computing the loss.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

    """
    def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        """ Classify (or regress on) each input sequence from BERT's pooled output.

            Returns ``(logits,)`` followed by any hidden states / attentions from ``BertModel``.
            When ``labels`` is given, the loss is prepended: MSE if ``num_labels == 1``, otherwise
            cross-entropy.
        """
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression.
                # MSELoss needs the target in the same floating dtype as the input. Without this
                # cast, integer labels (the documented LongTensor) raise a dtype-mismatch error.
                # Float labels are unchanged.
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)


@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForMultipleChoice(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above).
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
        input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
        labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, classification_scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForMultipleChoice, self).__init__(config)

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # One score per (example, choice) pair; scores are regrouped per example in forward().
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        """ Score each candidate choice and (optionally) compute the cross-entropy over choices.

            The per-token inputs are expected with the choices on dimension 1
            (``input_ids.shape[1]`` is taken as ``num_choices``). Returns ``(reshaped_logits,)``
            of shape ``(batch_size, num_choices)``, followed by any hidden states / attentions from
            ``BertModel``. When ``labels`` is given, the loss is prepended.
        """
        num_choices = input_ids.shape[1]

        # Fold the choice dimension into the batch: every choice is encoded as an independent
        # sequence of length input_ids.size(-1).
        input_ids = input_ids.view(-1, input_ids.size(-1))
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Unfold back: one row per original example, one column per choice.
        reshaped_logits = logits.view(-1, num_choices)

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            # labels hold the index of the correct choice, in [0, num_choices - 1].
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)


thomwolf's avatar
thomwolf committed
990
@add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForTokenClassification(BertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            When ``attention_mask`` is given, only positions where it equals 1 contribute to the loss.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForTokenClassification.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, scores = outputs[:2]

    """
    def __init__(self, config):
        super(BertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels

        # Keep sub-module creation order: it determines parameter registration order.
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        """Produce per-token label scores and, if ``labels`` is given, the loss.

        Returns ``((loss), scores, (hidden_states), (attentions))``, where the
        trailing entries are whatever ``self.bert`` returned after its first two
        outputs.
        """
        bert_outputs = self.bert(input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=position_ids,
                                 head_mask=head_mask)

        # First encoder output holds one hidden vector per token.
        token_states = self.dropout(bert_outputs[0])
        logits = self.classifier(token_states)

        extra = bert_outputs[2:]  # hidden states / attentions, if present
        if labels is None:
            return (logits,) + extra

        criterion = CrossEntropyLoss()
        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = labels.view(-1)
        if attention_mask is not None:
            # Drop padded positions so they do not contribute to the loss.
            keep = attention_mask.view(-1) == 1
            flat_logits = flat_logits[keep]
            flat_labels = flat_labels[keep]
        token_loss = criterion(flat_logits, flat_labels)

        return (token_loss, logits) + extra
thomwolf's avatar
thomwolf committed
1060
1061


thomwolf's avatar
thomwolf committed
1062
@add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForQuestionAnswering(BertPreTrainedModel):
    r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`): positions at or beyond
            `sequence_length` are not taken into account for computing the loss, and negative positions are clamped to 0.
            The tensor passed in is not modified.
        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`): positions at or beyond
            `sequence_length` are not taken into account for computing the loss, and negative positions are clamped to 0.
            The tensor passed in is not modified.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total span extraction loss: the average of the Cross-Entropy losses for the start and end positions.
        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        start_positions = torch.tensor([1])
        end_positions = torch.tensor([3])
        outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        loss, start_scores, end_scores = outputs[:3]

    """
    def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        # One score per label for every token; forward() splits these into start/end logits.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                start_positions=None, end_positions=None):
        """Compute span start/end logits and, when both position tensors are given, the span loss.

        Returns ``((loss), start_logits, end_logits, (hidden_states), (attentions))``.
        The trailing entries are whatever ``self.bert`` returned after its first
        two outputs. ``start_positions`` / ``end_positions`` are never modified in place.
        """
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        sequence_output = outputs[0]

        # Unpacking the split into exactly two tensors requires config.num_labels == 2.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,) + outputs[2:]  # add hidden states and attention if they are here
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the split adds a dimension; drop it.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs: clamp them onto
            # ``ignored_index`` so the loss ignores them. Out-of-place ``clamp`` (not ``clamp_``)
            # so the caller's label tensors are left untouched.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)