# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""


import logging
import math
import os

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .activations import gelu_new
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer


logger = logging.getLogger(__name__)

GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
    "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin",
    "gpt2-xl": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-xl-pytorch_model.bin",
    "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-pytorch_model.bin",
}


def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
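# A minimal usage sketch for the loader above (hypothetical local paths; it assumes a
# standard TF GPT-2 checkpoint plus its JSON config file are already on disk):
#
#     config = GPT2Config.from_json_file("/path/to/gpt2/config.json")
#     model = GPT2Model(config)
#     model = load_tf_weights_in_gpt2(model, config, "/path/to/gpt2/model.ckpt")
#     torch.save(model.state_dict(), "/path/to/gpt2/pytorch_model.bin")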


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
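        # nd is the current query length and ns the total key length (queries plus any cached
        # past); slicing the fixed lower-triangular `bias` buffer to [ns - nd : ns, :ns] keeps
        # the causal mask aligned when `layer_past` makes ns larger than nd.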
        b = self.bias[:, :, ns - nd : ns, :ns]
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)


class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu_new
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
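        # GPT-2 blocks use a "pre-norm" residual layout: LayerNorm is applied to the input of
        # each sub-layer (ln_1 before attention, ln_2 before the MLP) and the sub-layer output
        # is added back onto the unnormalized residual stream.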
        output_attn = self.attn(
            self.ln_1(x), layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)


class GPT2PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = GPT2Config
    pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


GPT2_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
            ``input_ids_length`` = ``sequence_length`` if ``past`` is None else 1
            Indices of input sequence tokens in the vocabulary.
            If using `past` as an input, make sure that `input_ids` are those of the last position.

            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__

        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`, defaults to :obj:`None`):
            ``input_ids_length`` = ``sequence_length`` if ``past`` is None else 1
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.
            If using `past` as an input, make sure that `token_type_ids` correspond to the `input_ids` of the last position.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""


@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.output_past = config.output_past

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)
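        # Hypothetical example: prune heads 0 and 2 of the first block and head 1 of the third
        # block of a loaded model via the public PreTrainedModel.prune_heads entry point:
        #
        #     model = GPT2Model.from_pretrained("gpt2")
        #     model.prune_heads({0: [0, 2], 2: [1]})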

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import GPT2Tokenizer, GPT2Model
        import torch

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2Model.from_pretrained('gpt2')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask.
        if attention_mask is not None:
            batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is simpler than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = (
                    head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
            )

            hidden_states, present = outputs[:2]
            if self.output_past:
                presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_past:
            outputs = outputs + (presents,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (presents), (all hidden_states), (attentions)


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
        # only keep the last token of input_ids if past is defined in kwargs
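        # During generation (e.g. via generate()), the cached key/value pairs in `past` already
        # cover every earlier position, so only the most recent token has to be fed back in.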
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {"input_ids": input_ids, "past": past}

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2LMHeadModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2LMHeadModel.from_pretrained('gpt2')

        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]

        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings,
    the classification head takes as input the hidden state at a specified classification token index in the input sequence.
""",
    GPT2_START_DOCSTRING,
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
    ):
        r"""
        mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to the index of the last token of the input):
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1]``.
        lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``
        mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``lm_labels`` is provided):
            Language modeling loss.
        mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`multiple_choice_labels` is provided):
            Multiple choice classification loss.
        lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        import torch
        from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

        # Add a [CLS] to the vocabulary (we should train it also!)
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})
        model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size
        print(tokenizer.cls_token_id, len(tokenizer))  # The newly added token is the last token of the vocabulary

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        encoded_choices = [tokenizer.encode(s) for s in choices]
        cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
            outputs = (loss,) + outputs
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)