# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""


import logging
import os

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer


logger = logging.getLogger(__name__)


GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://cdn.huggingface.co/gpt2-pytorch_model.bin",
    "gpt2-medium": "https://cdn.huggingface.co/gpt2-medium-pytorch_model.bin",
    "gpt2-large": "https://cdn.huggingface.co/gpt2-large-pytorch_model.bin",
    "gpt2-xl": "https://cdn.huggingface.co/gpt2-xl-pytorch_model.bin",
    "distilgpt2": "https://cdn.huggingface.co/distilgpt2-pytorch_model.bin",
}

def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
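
# Illustrative usage sketch (added comment, not part of the original module): converting an
# OpenAI-style TF checkpoint with the loader above. The paths below are placeholders.
#
#     config = GPT2Config.from_json_file("/path/to/gpt2/config.json")
#     model = GPT2Model(config)
#     model = load_tf_weights_in_gpt2(model, config, "/path/to/gpt2/model.ckpt")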


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer(
            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)
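        # Added note: `heads` holds original head indices, so `pruned_heads` accumulates them as
        # given, while the remapping above converts each index to its position in the already
        # pruned weights (e.g. pruning head 5 after head 2 targets row block 4 of 11 remaining).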
    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / (float(v.size(-1)) ** 0.5)
        nd, ns = w.size(-2), w.size(-1)
        mask = self.bias[:, :, ns - nd : ns, :ns]
        w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
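        # Added note: with cached decoding the query covers only the nd new tokens while keys span
        # all ns positions, so slicing the causal buffer as [ns - nd : ns, :ns] aligns the mask
        # with the last nd query rows (e.g. nd=1, ns=6 lets the new token attend to all 6 keys).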

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        else:
            present = (None,)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)
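
    # Illustrative sketch (added comment): the cache round-trip above for GPT-2 base
    # (12 heads, 64 features per head). Keys are kept transposed inside `_attn`, which is why
    # `present` transposes them back before stacking.
    #
    #     x = torch.randn(1, 5, 768)                      # (batch, seq_length, n_embd)
    #     a, present = attention_layer(x, use_cache=True)[:2]
    #     # a:       (1, 5, 768)
    #     # present: (2, 1, 12, 5, 64)  stacked (key, value) to pass back as `layer_past`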


class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)
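
    # Added note: this is the pre-LayerNorm residual layout used by GPT-2 -- each sub-layer sees a
    # normalized input while the residual adds the un-normalized stream:
    #     x = x + attn(ln_1(x));  x = x + mlp(ln_2(x))
    # (the original Transformer instead normalizes after the residual addition).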


class GPT2PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = GPT2Config
    pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


GPT2_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
            :obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
            Indices of input sequence tokens in the vocabulary.

            If `past` is used, only `input_ids` that do not have their past calculated should be passed as `input_ids`.

            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__

        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
            The `input_ids` which have their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`, defaults to :obj:`None`):
            ``input_ids_length`` = ``sequence_length`` if ``past`` is ``None`` else ``1``
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
            If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
        use_cache (:obj:`bool`):
            If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
"""

@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
    ):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
            If `past` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import GPT2Tokenizer, GPT2Model
        import torch

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2Model.from_pretrained('gpt2')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask.
        if attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # This attention mask is simpler than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0
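            # Added worked example: a padding mask [1, 1, 0] becomes [0.0, 0.0, -10000.0] here, so
            # padded key positions get a large negative bias and ~zero weight after the softmax in
            # Attention._attn.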

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
            )

            hidden_states, present = outputs[:2]
            if use_cache is True:
                presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if use_cache is True:
            outputs = outputs + (presents,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (presents), (all hidden_states), (attentions)


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
        # only keep the last token for input_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=True,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided):
            Language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]

        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
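            # Added worked example: for tokens [t0, t1, t2, t3], the logits at positions 0..2 are
            # scored against labels t1..t3, so each position is trained to predict the next token.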
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs
        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
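
    # Illustrative sketch (added comment): free-running generation with this head via the generic
    # `generate()` helper inherited from PreTrainedModel; greedy decoding shown, the prompt and
    # length are arbitrary choices.
    #
    #     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    #     model = GPT2LMHeadModel.from_pretrained("gpt2")
    #     input_ids = tokenizer.encode("The Manhattan bridge", return_tensors="pt")
    #     output_ids = model.generate(input_ids, max_length=30)
    #     print(tokenizer.decode(output_ids[0]))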


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings,
    the classification head takes as input the hidden state of a specified classification token in the input sequence.
""",
    GPT2_START_DOCSTRING,
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
        use_cache=True,
    ):
        r"""
        mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input):
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1]``.
        lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``
        mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``lm_labels`` is provided):
            Language modeling loss.
        mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
            Multiple choice classification loss.
        lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

        # Add a [CLS] to the vocabulary (we should train it also!)
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size
        print(tokenizer.cls_token_id, len(tokenizer))  # The newly added token is the last token of the vocabulary

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

        """
723
724
725
726
727
728
729
730
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
731
            use_cache=use_cache,
732
        )
733

thomwolf's avatar
thomwolf committed
734
        hidden_states = transformer_outputs[0]
735

thomwolf's avatar
thomwolf committed
736
        lm_logits = self.lm_head(hidden_states)
thomwolf's avatar
thomwolf committed
737
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
thomwolf's avatar
thomwolf committed
738

739
        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
thomwolf's avatar
thomwolf committed
740
741
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
742
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
743
            outputs = (loss,) + outputs
thomwolf's avatar
thomwolf committed
744
        if lm_labels is not None:
745
746
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
LysandreJik's avatar
LysandreJik committed
747
            loss_fct = CrossEntropyLoss()
748
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
749
            outputs = (loss,) + outputs
thomwolf's avatar
thomwolf committed
750
751

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
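
    # Illustrative sketch (added comment): a common recipe is to train on a weighted sum of the two
    # losses returned above; the 2.0 weight on the LM loss is an assumption, not a library default.
    #
    #     lm_loss, mc_loss = model(input_ids, mc_token_ids=mc_token_ids,
    #                              lm_labels=input_ids, mc_labels=mc_labels)[:2]
    #     total_loss = mc_loss + 2.0 * lm_loss
    #     total_loss.backward()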