# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""


import logging
import os
import warnings

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import (
    Conv1D,
    PreTrainedModel,
    SequenceSummary,
    find_pruneable_heads_and_indices,
    prune_conv1d_layer,
)


logger = logging.getLogger(__name__)

_TOKENIZER_FOR_DOC = "GPT2Tokenizer"

GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
    "distilgpt2",
    # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
]


def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
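
# Usage sketch (added for illustration; not part of the upstream module). Converting an
# original TensorFlow GPT-2 checkpoint into this PyTorch implementation could look like the
# following; the paths are placeholders and the config is assumed to match the checkpoint.
#
#     config = GPT2Config.from_json_file("/path/to/gpt2_config.json")
#     model = GPT2Model(config)
#     model = load_tf_weights_in_gpt2(model, config, "/path/to/tf_checkpoint")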


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer(
            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
        )
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / (float(v.size(-1)) ** 0.5)
        nd, ns = w.size(-2), w.size(-1)
        mask = self.bias[:, :, ns - nd : ns, :ns]
        w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(
        self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
    ):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        else:
            present = (None,)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)
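
# Note on the cache layout (added for clarity; an observation about the code above, not an
# upstream comment): when ``use_cache`` is set, ``present`` is
# ``torch.stack((key.transpose(-2, -1), value))``, a tensor of shape
# ``(2, batch, n_head, seq_len, head_dim)``. The key is kept in its transposed attention
# layout during the forward pass, so it is transposed before stacking and transposed back
# when the cached ``layer_past`` is consumed on the next call.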


class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(
        self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False,
    ):
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)


class GPT2PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = GPT2Config
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


GPT2_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
            :obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
            Indices of input sequence tokens in the vocabulary.

            If `past` is used, only `input_ids` that do not have their past calculated should be passed as `input_ids`.

            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__

        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
            The `input_ids` which have their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`, defaults to :obj:`None`):
            :obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``1``
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
            If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
        use_cache (:obj:`bool`):
            If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
"""

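# Decoding sketch (added for illustration; not part of the upstream module). The `past` /
# `use_cache` mechanics documented above allow incremental decoding: feed the full prompt
# once, then feed only the newly sampled token together with the returned `past`. Greedy
# sampling is used here purely to keep the example short.
#
#     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#     model = GPT2LMHeadModel.from_pretrained("gpt2")
#     input_ids = torch.tensor([tokenizer.encode("Hello, my dog")])
#     past = None
#     for _ in range(10):
#         logits, past = model(input_ids, past=past, use_cache=True)[:2]
#         next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
#         input_ids = next_token  # only the new token is fed on the next step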

@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="gpt2")
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
            If `past` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask.
        if attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states, present = outputs[:2]
            if use_cache is True:
                presents = presents + (present,)

            if output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if use_cache is True:
            outputs = outputs + (presents,)
        if output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (presents), (all hidden_states), (attentions)
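
# Usage sketch (added for illustration; not part of the upstream module): running the bare
# model to obtain the final hidden states. With the default ``use_cache=True`` the second
# element of the returned tuple holds the per-layer key/value cache (``presents``).
#
#     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#     model = GPT2Model.from_pretrained("gpt2")
#     input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
#     last_hidden_state, presents = model(input_ids)[:2]
#     # last_hidden_state has shape (batch_size, sequence_length, config.n_embd)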


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="gpt2")
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        """
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
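
# Usage sketch (added for illustration; not part of the upstream module): computing the
# language-modeling loss on a batch of token ids. Because the labels are shifted inside the
# model, ``labels`` can simply be set to ``input_ids``.
#
#     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#     model = GPT2LMHeadModel.from_pretrained("gpt2")
#     input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
#     loss, lm_logits = model(input_ids, labels=input_ids)[:2]
#     loss.backward()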


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings,
    the classification head takes as input the input of a specified classification token index in the input sequence).
""",
    GPT2_START_DOCSTRING,
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        mc_labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        **kwargs
    ):
        r"""
        mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1[``.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`)
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``
        mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`)
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Language modeling loss.
        mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
            Multiple choice classification loss.
        lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        >>> import torch
        >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

        >>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        >>> model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

        >>> # Add a [CLS] to the vocabulary (we should train it also!)
        >>> num_added_tokens = tokenizer.add_special_tokens({'cls_token': '[CLS]'})

        >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size

        >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        >>> encoded_choices = [tokenizer.encode(s) for s in choices]
        >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        >>> mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
        >>> lm_prediction_scores, mc_prediction_scores = outputs[:2]

        """
        if "lm_labels" in kwargs:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                DeprecationWarning,
            )
            labels = kwargs.pop("lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
            outputs = (loss,) + outputs
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)