# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""


import logging
import os
import warnings

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import (
    Conv1D,
    PreTrainedModel,
    SequenceSummary,
    find_pruneable_heads_and_indices,
    prune_conv1d_layer,
)


logger = logging.getLogger(__name__)

GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
    "distilgpt2",
    # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
]


def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
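# A minimal usage sketch (not part of the original file), assuming a locally downloaded
# TF checkpoint that matches the default 124M-parameter GPT-2 configuration; the
# checkpoint path below is a hypothetical placeholder.
#
#     config = GPT2Config()
#     model = GPT2Model(config)
#     model = load_tf_weights_in_gpt2(model, config, "/path/to/tf_checkpoint/model.ckpt")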


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer(
            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
        )
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / (float(v.size(-1)) ** 0.5)
        nd, ns = w.size(-2), w.size(-1)
        mask = self.bias[:, :, ns - nd : ns, :ns]
        w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(
        self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
    ):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        else:
            present = (None,)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)
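# A minimal shape sketch (not part of the original file) of the Attention module above,
# using a default GPT2Config (n_embd=768, n_head=12, n_ctx=1024):
#
#     config = GPT2Config()
#     attn = Attention(config.n_embd, config.n_ctx, config, scale=True)
#     x = torch.randn(2, 5, config.n_embd)       # (batch, seq_len, n_embd)
#     a, present = attn(x, use_cache=True)[:2]
#     # a: (2, 5, 768) after merge_heads and c_proj
#     # present: stacked (key, value) cache of shape (2, 2, 12, 5, 64)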


class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(
        self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False,
    ):
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)
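# A minimal sketch (not part of the original file) of the pre-LayerNorm residual block above,
# x -> x + Attn(LN(x)) -> x + MLP(LN(x)), which keeps the hidden size unchanged:
#
#     config = GPT2Config()
#     block = Block(config.n_ctx, config, scale=True)
#     hidden = torch.randn(2, 5, config.n_embd)
#     hidden_out, present = block(hidden, use_cache=True)[:2]   # hidden_out: (2, 5, 768)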


class GPT2PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = GPT2Config
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


GPT2_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
            :obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
            Indices of input sequence tokens in the vocabulary.

            If `past` is used, only `input_ids` that do not have their past calculated should be passed as `input_ids`.

            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__

        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
            The `input_ids` which have their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`, defaults to :obj:`None`):
            :obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``1``
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
            If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
        use_cache (:obj:`bool`):
            If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
"""


@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
            If `past` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import GPT2Tokenizer, GPT2Model
        import torch

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2Model.from_pretrained('gpt2')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask.
        if attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states, present = outputs[:2]
            if use_cache is True:
                presents = presents + (present,)

            if output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if use_cache is True:
            outputs = outputs + (presents,)
        if output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (presents), (all hidden_states), (attentions)
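# A minimal sketch (not part of the original file) of head pruning through the `_prune_heads`
# hook defined above, normally reached via `PreTrainedModel.prune_heads`, on a randomly
# initialised model:
#
#     model = GPT2Model(GPT2Config())
#     model.prune_heads({0: [0, 1], 5: [3]})   # drop heads 0 and 1 of layer 0, head 3 of layer 5
#     # layer 0 is left with 10 heads; its c_attn / c_proj weights shrink accordingly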


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]

        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
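# A minimal greedy-decoding sketch (not part of the original file) illustrating the caching
# contract of `prepare_inputs_for_generation` above: once `past` is set, only the most recent
# token is fed back in. Assumes the pretrained 'gpt2' weights are available; GPT2Tokenizer
# comes from the same library but is not imported by this module.
#
#     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#     model = GPT2LMHeadModel.from_pretrained("gpt2")
#     input_ids = torch.tensor([tokenizer.encode("Hello, my dog is")])
#     past = None
#     for _ in range(5):
#         lm_logits, past = model(input_ids, past=past, use_cache=True)[:2]
#         input_ids = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)   # feed only the new token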


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling and a multiple-choice classification
643
644
645
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings,
    the classification head takes as input the input of a specified classification token index in the input sequence).
""",
    GPT2_START_DOCSTRING,
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        mc_labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        **kwargs
    ):
        r"""
        mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input):
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1]``.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``
        mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Language modeling loss.
        mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
            Multiple choice classification loss.
        lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

        # Add a [CLS] to the vocabulary (we should train it also!)
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size
        print(tokenizer.cls_token_id, len(tokenizer))  # The newly added token is the last token of the vocabulary

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

        """
        if "lm_labels" in kwargs:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                DeprecationWarning,
            )
            labels = kwargs.pop("lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
            outputs = (loss,) + outputs
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
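# A minimal sketch (not part of the original file) continuing the docstring example above for
# GPT2DoubleHeadsModel: supplying `mc_labels` prepends the multiple-choice classification loss
# to the returned tuple (here choice 0 is marked as the correct one).
#
#     mc_labels = torch.tensor([0])                                      # Batch size: 1
#     outputs = model(input_ids, mc_token_ids=mc_token_ids, mc_labels=mc_labels)
#     mc_loss, lm_prediction_scores, mc_prediction_scores = outputs[:3]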