# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import json
import logging
import math
import os
import sys
from io import open

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings

logger = logging.getLogger(__name__)

GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
    "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin",
    "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-pytorch_model.bin",
    "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",
}


def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        for m_name in name:
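            # TF scope components such as "h0" are split into ["h", "0", ""] below, so the
            # trailing index can select the matching sub-module (e.g. the 0-th Block in model.h).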
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                l = re.split(r"(\d+)", m_name)
            else:
                l = [m_name]
            if l[0] == "w" or l[0] == "g":
                pointer = getattr(pointer, "weight")
            elif l[0] == "b":
                pointer = getattr(pointer, "bias")
            elif l[0] == "wpe" or l[0] == "wte":
                pointer = getattr(pointer, l[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, l[0])
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model


def gelu(x):
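    # Tanh approximation of GELU, matching the original GPT-2 TensorFlow implementation
    # (slightly different numerically from the exact erf-based formulation).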
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
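        # c_attn packs the query, key and value projections along its output dimension, so the
        # kept feature indices are repeated once per projection with the corresponding offset.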
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
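        # Slice the cached lower-triangular mask so the nd query positions can attend to every
        # key position that precedes them, including keys carried over through `layer_past`.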
        b = self.bias[:, :, ns - nd : ns, :ns]
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)


class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
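        # Pre-norm residual block: LayerNorm is applied before each sub-layer and the
        # un-normalized input is added back to the sub-layer output.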
        output_attn = self.attn(
            self.ln_1(x), layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)


class GPT2PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = GPT2Config
    pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


GPT2_START_DOCSTRING = r"""    OpenAI GPT-2 model was proposed in
    `Language Models are Unsupervised Multitask Learners`_
    by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
    It's a causal (unidirectional) transformer pre-trained using  language modeling on a very large
    corpus of ~40 GB of text data.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matter related to general usage and behavior.

    .. _`Language Models are Unsupervised Multitask Learners`:
        https://openai.com/blog/better-language-models/

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.
            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **past**:
            list of ``torch.FloatTensor`` (one for each layer):
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""


@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
    GPT2_INPUTS_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **past**:
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(2, batch_size, num_heads, sequence_length, embed_size_per_head)``:
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2Model.from_pretrained('gpt2')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
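
        # A minimal sketch of incremental decoding with `past` (illustrative only): feeding
        # `past` back in avoids recomputing keys/values for tokens that were already processed.
        past = None
        context = torch.tensor([tokenizer.encode("Hello, my dog")])
        hidden_states, past = model(context, past=past)[:2]
        next_token = torch.tensor([[tokenizer.encode(" is")[0]]])
        hidden_states, past = model(next_token, past=past)[:2]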

    """

    def __init__(self, config):
        super(GPT2Model, self).__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.output_past = config.output_past

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
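            # When `past` is supplied, only the new tokens are fed in, so their absolute
            # positions start at `past_length` rather than at 0.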
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask.
        if attention_mask is not None:
            attention_mask = attention_mask.view(-1, input_shape[-1])
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
            )

            hidden_states, present = outputs[:2]
            if self.output_past:
                presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_past:
            outputs = outputs + (presents,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (presents), (all hidden_states), (attentions)


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """,
    GPT2_START_DOCSTRING,
    GPT2_INPUTS_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **past**:
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(2, batch_size, num_heads, sequence_length, embed_size_per_head)``:
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]
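
        # A minimal greedy-decoding sketch (illustrative only): repeatedly append the argmax of
        # the last position's logits and feed the growing sequence back into the model.
        generated = input_ids
        for _ in range(5):
            logits = model(generated)[0]
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat((generated, next_token), dim=1)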

    """

    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling and a multiple-choice classification
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
The language modeling head has its weights tied to the input embeddings,
the classification head takes as input the hidden state at a specified classification token index in the input sequence.
""",
    GPT2_START_DOCSTRING,
    GPT2_INPUTS_DOCSTRING,
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    r"""
        **mc_token_ids**: (`optional`, default to index of the last token of the input) ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1]``.
        **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``
        **mc_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Language modeling loss.
        **mc_loss**: (`optional`, returned when ``multiple_choice_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Multiple choice classification loss.
        **lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        **past**:
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(2, batch_size, num_heads, sequence_length, embed_size_per_head)``:
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
        
        # Add a [CLS] to the vocabulary (we should train it also!)
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size
        print(tokenizer.cls_token_id, len(tokenizer))  # The new [CLS] token is the last token of the vocabulary
        
        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]
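
        # For this batch of 1 example with 2 choices, lm_prediction_scores has shape
        # (1, 2, sequence_length, vocab_size) and mc_prediction_scores has shape (1, 2).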

    """

    def __init__(self, config):
        super(GPT2DoubleHeadsModel, self).__init__(config)
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
    ):
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
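        # With the default GPT-2 config (summary_type="cls_index", and num_labels set to 1 above),
        # SequenceSummary gathers the hidden state at each `mc_token_ids` position and projects
        # it to a single score per choice.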
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
            outputs = (loss,) + outputs
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)