# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""


import logging
import os

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer


logger = logging.getLogger(__name__)

GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
    "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin",
    "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-pytorch_model.bin",
    "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",
}


def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
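        # A TF variable name looks like "model/h0/attn/c_attn/w": the leading "model/" prefix is
        # stripped and the remaining scope names are walked to find the matching PyTorch
        # sub-module and parameter ("w"/"g" map to `weight`, "b" maps to `bias`).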
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
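
# Illustrative usage of the converter above (a sketch, not executed here); the checkpoint and
# config paths below are placeholders, not files shipped with the library:
#
#   config = GPT2Config.from_json_file("/path/to/gpt2/config.json")
#   model = GPT2Model(config)
#   model = load_tf_weights_in_gpt2(model, config, "/path/to/gpt2/model.ckpt")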


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer(
            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
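        # c_attn packs query, key and value side by side along its output dimension, so the same
        # per-head index is reused three times, offset by `split_size` for the key and value blocks.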
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / (v.size(-1) ** 0.5)
        nd, ns = w.size(-2), w.size(-1)
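        # Causal mask: `self.bias` is the lower-triangular matrix built in __init__. With a cached
        # past, the query covers only the last `nd` positions while the key covers all `ns`
        # positions, so the mask is sliced to its last `nd` rows.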
        mask = self.bias[:, :, ns - nd : ns, :ns]
        w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
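        # Shapes after splitting: query and value are (batch, head, seq_length, head_features);
        # key is (batch, head, head_features, seq_length) so that `query @ key` directly yields
        # attention scores of shape (batch, head, seq_length, seq_length).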
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
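            # `present` has shape (2, batch, head, seq_length, head_features); the key is
            # transposed back to its matmul layout when it is re-used as `layer_past` above.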
        else:
            present = (None,)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)


class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
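        # GPT-2 uses pre-layer-norm residual blocks: LayerNorm is applied before the attention
        # and MLP sub-layers and their outputs are added back onto the residual stream.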
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)


class GPT2PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = GPT2Config
    pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


GPT2_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            If `past` is used, optionally only the last `input_ids` have to be input (see `past`).

            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__

        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
            If `past` is used, the user can optionally input only the last `input_ids` (those that don't have their past given to this model) of shape :obj:`(batch_size, 1)` instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`, defaults to :obj:`None`):
            ``input_ids_length`` = ``sequence_length`` if ``past`` is ``None`` else ``1``
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
            If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
        use_cache (:obj:`bool`):
            If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
"""


@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)
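
    # Illustrative usage (a sketch): prune heads 0 and 2 of layer 0 and head 1 of layer 2 through
    # the base-class helper, e.g. ``model.prune_heads({0: [0, 2], 2: [1]})``.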

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
    ):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
            If `past` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import GPT2Tokenizer, GPT2Model
        import torch

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2Model.from_pretrained('gpt2')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """

        # If using past key value states, only the last tokens
        # should be given as an input
        if past is not None:
            if input_ids is not None:
                input_ids = input_ids[:, -1:]
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1:]

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask.
        if attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
            )

            hidden_states, present = outputs[:2]
            if use_cache is True:
                presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if use_cache is True:
            outputs = outputs + (presents,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (presents), (all hidden_states), (attentions)


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
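
    # Illustrative sketch (not executed here) of how `past` speeds up generation; `input_ids` is
    # assumed to hold prompt token ids and `num_new_tokens` is a placeholder:
    #
    #   logits, past = model(input_ids)[:2]
    #   next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    #   for _ in range(num_new_tokens):
    #       logits, past = model(next_token, past=past)[:2]
    #       next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)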

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=True,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
            Language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]

        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings; the classification
    head takes as input the hidden state of a specified classification token in the input sequence.
""",
    GPT2_START_DOCSTRING,
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
        use_cache=True,
    ):
        r"""
        mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input)
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1]``.
        lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`)
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``
        mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`)
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``lm_labels`` is provided):
            Language modeling loss.
        mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`multiple_choice_labels` is provided):
            Multiple choice classification loss.
        lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

        # Add a [CLS] to the vocabulary (we should train it also!)
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size
        print(tokenizer.cls_token_id, len(tokenizer))  # The newly added [CLS] token is the last token of the vocabulary

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
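        # The multiple-choice head summarizes each choice by the hidden state at the position
        # given by `mc_token_ids` (typically a classification token appended to each choice).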
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
            outputs = (loss,) + outputs
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)