"megatron/git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "ed6806ac35e84a8801da4f05766feba47cbb693b"
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""


import logging
import os
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
from .file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from .modeling_utils import (
    Conv1D,
    PreTrainedModel,
    SequenceSummary,
    find_pruneable_heads_and_indices,
    prune_conv1d_layer,
)


logger = logging.getLogger(__name__)

_CONFIG_FOR_DOC = "GPT2Config"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"

GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
    "distilgpt2",
    # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
]


def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import re
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
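
# A minimal conversion sketch for the helper above (illustrative only and not executed
# here; "./gpt2_tf/model.ckpt" is a hypothetical checkpoint path, and GPT2Model is
# defined further down in this file):
#
#     config = GPT2Config()
#     model = GPT2Model(config)
#     load_tf_weights_in_gpt2(model, config, "./gpt2_tf/model.ckpt")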


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
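        # "bias" is a fixed lower-triangular causal mask (1 where a query position may
        # attend to a key position, 0 elsewhere); "masked_bias" holds the large negative
        # value written into masked positions before the softmax.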
        self.register_buffer(
            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
        )
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / (float(v.size(-1)) ** 0.5)
        nd, ns = w.size(-2), w.size(-1)
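        # Slice the precomputed causal mask so that the nd query positions (the newest
        # tokens) line up against all ns key positions (cached past plus current input).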
        mask = self.bias[:, :, ns - nd : ns, :ns]
        w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(
        self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
    ):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
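        # Keys are kept as (batch, head, head_features, seq_length) so that
        # torch.matmul(q, k) in _attn directly yields attention scores; the cache below
        # stores them as (batch, head, seq_length, head_features), hence the transposes.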
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        else:
            present = (None,)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)


class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
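    # A single GPT-2 transformer block: LayerNorm -> self-attention -> residual add,
    # followed by LayerNorm -> MLP -> residual add (pre-norm ordering).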
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(
        self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False,
    ):
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)


class GPT2PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = GPT2Config
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput):
    """
    Base class for outputs of models predicting if two sentences are consecutive or not.

    Args:
        lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Language modeling loss.
        mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
            Multiple choice classification loss.
        lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`,  with each tensor of shape
            :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
            ``past_key_values`` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    lm_loss: Optional[torch.FloatTensor]
    mc_loss: Optional[torch.FloatTensor]
    lm_logits: torch.FloatTensor
    mc_logits: torch.FloatTensor
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


GPT2_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
            :obj:`input_ids_length` = ``sequence_length`` if ``past_key_values`` is ``None`` else
            ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states).
            Indices of input sequence tokens in the vocabulary.

            If ``past_key_values`` is used, only ``input_ids`` that do not have their past calculated should be passed
            as ``input_ids``.

            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__

        past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see ``past_key_values`` output below). Can be used to speed up sequential decoding.
            The ``input_ids`` which have their past given to this model should not be passed as ``input_ids`` as they have already been computed.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`, defaults to :obj:`None`):
            ``input_ids_length`` = ``sequence_length`` if ``past_key_values`` is ``None`` else ``1``
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
            If ``past_key_values`` is used, optionally only the last `inputs_embeds` have to be input (see ``past_key_values``).
        use_cache (:obj:`bool`):
            If `use_cache` is True, ``past_key_values`` key value states are returned and can be used to speed up decoding (see ``past_key_values``). Defaults to `True`.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""


@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)
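        # Example: _prune_heads({0: [0, 2]}) removes heads 0 and 2 of the first block;
        # this is normally reached through PreTrainedModel.prune_heads().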

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="gpt2",
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
        **kwargs,
    ):
        if "past" in kwargs:
            warnings.warn(
                "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("past")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = [None] * len(self.h)
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask.
        if attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
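        # Run the stack of transformer blocks, optionally collecting the per-layer
        # key/value caches, hidden states and attention weights along the way.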
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states, present = outputs[:2]
            if use_cache is True:
                presents = presents + (present,)

            if output_attentions:
                all_attentions = all_attentions + (outputs[2],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_tuple:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
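        # Hook used by the generation utilities (generate()) to build the model inputs
        # for the next step from the running sequence and the cached past.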
        # only last token for input_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="gpt2",
        output_type=CausalLMOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
        **kwargs,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``
        """
        if "past" in kwargs:
            warnings.warn(
                "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("past")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if return_tuple:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
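
# A minimal usage sketch for GPT2LMHeadModel (illustrative only; it assumes the public
# "gpt2" checkpoint and GPT2Tokenizer are available):
#
#     from transformers import GPT2Tokenizer, GPT2LMHeadModel
#     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#     model = GPT2LMHeadModel.from_pretrained("gpt2")
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     outputs = model(**inputs, labels=inputs["input_ids"])
#     loss, logits = outputs.loss, outputs.logits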


@add_start_docstrings(
    """The GPT2 Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings,
    the classification head takes as input the hidden state at a specified classification token index in the input sequence.
""",
    GPT2_START_DOCSTRING,
)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 1
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        mc_labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_tuple=None,
        **kwargs,
    ):
        r"""
        mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input)
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1[``.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`)
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``
        mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`)
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

    Return:

    Examples::

        >>> import torch
        >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

        >>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        >>> model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

        >>> # Add a [CLS] to the vocabulary (we should train it also!)
        >>> num_added_tokens = tokenizer.add_special_tokens({'cls_token': '[CLS]'})

        >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))  # Update the model embeddings with the new vocabulary size

        >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        >>> encoded_choices = [tokenizer.encode(s) for s in choices]
        >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        >>> mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
        >>> lm_logits = outputs.lm_logits
        >>> mc_logits = outputs.mc_logits

        """
        if "lm_labels" in kwargs:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("lm_labels")
        if "past" in kwargs:
            warnings.warn(
                "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("past")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_tuple=return_tuple,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        mc_loss = None
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
        lm_loss = None
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if return_tuple:
            output = (lm_logits, mc_logits) + transformer_outputs[1:]
            if mc_loss is not None:
                output = (mc_loss,) + output
            return ((lm_loss,) + output) if lm_loss is not None else output

        return GPT2DoubleHeadsModelOutput(
            lm_loss=lm_loss,
            mc_loss=mc_loss,
            lm_logits=lm_logits,
            mc_logits=mc_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )