modeling_gpt2.py 36.1 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
thomwolf's avatar
thomwolf committed
2
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
thomwolf's avatar
thomwolf committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""

18
19
from __future__ import absolute_import, division, print_function, unicode_literals

thomwolf's avatar
thomwolf committed
20
21
22
23
24
25
26
27
28
29
30
31
32
import collections
import json
import logging
import math
import os
import sys
from io import open

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

33
from .file_utils import cached_path
34
from .model_utils import Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, prune_conv1d_layer
thomwolf's avatar
thomwolf committed
35
36
37
38
from .modeling import BertLayerNorm as LayerNorm

logger = logging.getLogger(__name__)

# Shortcut name -> download URL of the pretrained PyTorch weights.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
}

# Shortcut name -> download URL of the matching JSON configuration.
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
}
thomwolf's avatar
thomwolf committed
43

44
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """ Load TensorFlow checkpoint weights into a PyTorch GPT-2 model, in place.

    Every variable of the checkpoint is read, its leading "model/" scope is
    stripped, and the remaining "/"-separated path is walked attribute by
    attribute on ``model`` ("w"/"g" -> ``weight``, "b" -> ``bias``,
    "wte"/"wpe" -> that embedding's ``weight``, "h0" -> ``h[0]``, ...).

    Args:
        model: the PyTorch model whose parameters are overwritten.
        config: not read by this function.
        gpt2_checkpoint_path: path of the TensorFlow checkpoint.

    Returns:
        ``model`` itself, with its parameter data replaced.

    Raises:
        ImportError: if TensorFlow (or NumPy) cannot be imported.
        AssertionError: if a checkpoint array's shape differs from the target
            parameter's shape; both shapes are carried in the exception args.
    """
    try:
        import re
        import numpy as np  # noqa: F401  (checks the dependency is available)
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
                     "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from %s", tf_path)
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight %s with shape %s", name, shape)
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split('/')
        pointer = model
        for m_name in name:
            # "h0" -> ["h", "0", ""]: attribute name followed by a layer index.
            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                scope_names = re.split(r'(\d+)', m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == 'w' or scope_names[0] == 'g':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'b':
                pointer = getattr(pointer, 'bias')
            elif scope_names[0] == 'wpe' or scope_names[0] == 'wte':
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        # Explicit raise instead of a bare ``assert`` so the check is not
        # stripped under ``python -O``; same exception type and args as before.
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        logger.info("Initialize PyTorch weight %s", "/".join(name))
        pointer.data = torch.from_numpy(array)
    return model


def gelu(x):
    """Gaussian Error Linear Unit, tanh approximation.

    Elementwise ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))


102
class GPT2Config(PretrainedConfig):
    """Configuration class to store the configuration of a `GPT2Model`.
    """
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size_or_config_json_file=50257,
        n_special=0,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        predict_special_tokens=True,
        **kwargs
    ):
        """Constructs GPT2Config.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model`, or the path of a
                JSON configuration file. When a path is given, every key/value pair of the file is copied
                onto this object and the other keyword arguments below are ignored.
            n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLS]', ...)
            n_positions: Number of positional embeddings.
            n_ctx: Size of the causal mask (usually same as n_positions).
            n_embd: Dimensionality of the embeddings and hidden states.
            n_layer: Number of hidden layers (Transformer blocks).
            n_head: Number of attention heads for each attention layer.
            resid_pdrop: The dropout probability applied to the projected outputs of the attention
                and MLP sub-layers (also used by the multiple-choice head).
            embd_pdrop: The dropout ratio for the embeddings.
            attn_pdrop: The dropout ratio for the attention probabilities.
            layer_norm_epsilon: epsilon to use in the layer norm layers
            initializer_range: The stddev of the normal initializer used for all weight matrices.
            predict_special_tokens: should we predict special tokens (when the model has a LM head)
            **kwargs: forwarded to `PretrainedConfig.__init__`.

        Raises:
            ValueError: if the first argument is neither an int nor a string path.
        """
        super(GPT2Config, self).__init__(**kwargs)

        # On Python 2 a path may also arrive as `unicode`; the version test short-circuits on Python 3.
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.n_special = n_special
            self.n_ctx = n_ctx
            self.n_positions = n_positions
            self.n_embd = n_embd
            self.n_layer = n_layer
            self.n_head = n_head
            self.resid_pdrop = resid_pdrop
            self.embd_pdrop = embd_pdrop
            self.attn_pdrop = attn_pdrop
            self.layer_norm_epsilon = layer_norm_epsilon
            self.initializer_range = initializer_range
            self.predict_special_tokens = predict_special_tokens
        else:
            # Note the trailing space: the two literals are concatenated.
            raise ValueError(
                "First argument must be either a vocabulary size (int) "
                "or the path to a pretrained model config file (str)"
            )

    @property
    def total_tokens_embeddings(self):
        """Number of rows of the token embedding matrix: vocabulary plus special tokens."""
        return self.vocab_size + self.n_special

    @property
    def hidden_size(self):
        """Alias of `n_embd`."""
        return self.n_embd

    @property
    def num_attention_heads(self):
        """Alias of `n_head`."""
        return self.n_head

    @property
    def num_hidden_layers(self):
        """Alias of `n_layer`."""
        return self.n_layer


thomwolf's avatar
thomwolf committed
190
191

class Attention(nn.Module):
    """Multi-head causal self-attention sub-layer of a GPT-2 `Block`.

    Args:
        nx: input/output feature size (n_embd).
        n_ctx: side of the lower-triangular causal mask buffer; presumably the total
            (past + current) sequence length must not exceed it — TODO confirm.
        config: reads `n_head`, `attn_pdrop`, `resid_pdrop` and `output_attentions`.
        scale: if True, attention scores are divided by sqrt(head_features).
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Lower-triangular ones: 1 where a query may attend to a key, 0 for future keys.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        # One projection producing [query | key | value] concatenated on the feature dim.
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def prune_heads(self, heads):
        """Remove attention heads from this layer in place.

        Args:
            heads: iterable of head indices, in the layer's current numbering.
                Repeated indices are counted once.
        """
        # Deduplicate: counting a repeated head twice would shrink n_head and
        # split_size below the size of the pruned weights.
        heads = set(heads)
        if len(heads) == 0:
            return
        head_size = self.split_size // self.n_head
        mask = torch.ones(self.n_head, head_size)
        for head in heads:
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Keep the same feature indices in each of the query, key and value slices of c_attn.
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
        # Update hyper params
        self.split_size = head_size * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)

    def _attn(self, q, k, v, head_mask=None):
        """Causal scaled dot-product attention.

        q and v are (batch, head, seq, head_features); k is already transposed to
        (batch, head, head_features, seq) by `split_heads(..., k=True)`.
        Returns [context] or [context, attention_probs] when output_attentions is set.
        """
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        # The last nd query rows of the mask against the first ns keys: this
        # offsets the mask correctly when cached past keys are prepended.
        b = self.bias[:, :, ns-nd:ns, :ns]
        # Masked scores become -1e4 so they vanish after the softmax.
        w = w * b - 1e4 * (1 - b)

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        """(batch, head, seq, head_features) -> (batch, seq, head * head_features)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        """(batch, seq, features) -> per-head layout; keys get head_features before seq."""
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, head_mask=None):
        """Returns [output, present] (+ [attention_probs] if output_attentions).

        `present` stacks (key, value) of the full sequence, including `layer_past`,
        so it can be passed back as `layer_past` on the next decoding step.
        """
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking

        attn_outputs = self._attn(query, key, value, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)
thomwolf's avatar
thomwolf committed
279
280
281
282
283
284
285
286
287


class MLP(nn.Module):
    """Position-wise feed-forward sub-layer of a GPT-2 `Block`.

    n_embd -> n_state projection, GELU, projection back to n_embd, then
    dropout with probability `config.resid_pdrop`.
    """

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        embed_dim = config.n_embd
        self.c_fc = Conv1D(n_state, embed_dim)
        self.c_proj = Conv1D(embed_dim, n_state)
        self.act = gelu
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))
thomwolf's avatar
thomwolf committed
294
295
296


class Block(nn.Module):
    """One pre-norm GPT-2 Transformer layer: x + attn(ln_1(x)), then + mlp(ln_2(.))."""

    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        width = config.n_embd
        self.ln_1 = LayerNorm(width, eps=config.layer_norm_epsilon)
        self.attn = Attention(width, n_ctx, config, scale)
        self.ln_2 = LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * width, config)

    def forward(self, x, layer_past=None, head_mask=None):
        """Returns [hidden, present] (+ [attentions]) — the attention extras are passed through."""
        attn_results = self.attn(self.ln_1(x), layer_past=layer_past, head_mask=head_mask)
        hidden = x + attn_results[0]  # residual around attention
        hidden = hidden + self.mlp(self.ln_2(hidden))  # residual around the MLP
        return [hidden] + attn_results[1:]  # x, present, (attentions)
thomwolf's avatar
thomwolf committed
315
316
317
318
319
320
321
322


class GPT2LMHead(nn.Module):
    """ Language Model Head for the transformer.

    A bias-free linear decoder whose weight is initialised from a copy of the
    token embedding matrix. When `predict_special_tokens` is False the logits
    are truncated to the first `vocab_size` entries.
    """

    def __init__(self, model_embeddings_weights, config):
        super(GPT2LMHead, self).__init__()
        self.n_embd = config.n_embd
        self.vocab_size = config.vocab_size
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        # Pass the configured flag explicitly: relying on the setter's default
        # (True) would silently override config.predict_special_tokens.
        self.set_embeddings_weights(model_embeddings_weights,
                                    predict_special_tokens=config.predict_special_tokens)

    def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
        """Re-initialise the decoder from `model_embeddings_weights` and set the special-token flag."""
        self.predict_special_tokens = predict_special_tokens
        # Export to TorchScript can't handle parameter sharing so we are cloning them:
        # the decoder starts as a copy of the embeddings but is NOT shared with them afterwards.
        self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())

    def forward(self, hidden_state):
        lm_logits = self.decoder(hidden_state)
        if not self.predict_special_tokens:
            lm_logits = lm_logits[..., :self.vocab_size]
        return lm_logits


class GPT2MultipleChoiceHead(nn.Module):
    """ Classifier Head for the transformer.

    Scores every choice with a linear layer applied to the hidden state of one
    classification token per choice.
    """

    def __init__(self, config):
        super(GPT2MultipleChoiceHead, self).__init__()
        self.n_embd = config.n_embd
        self.dropout = nn.Dropout2d(config.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
        self.linear = nn.Linear(config.n_embd, 1)
        nn.init.normal_(self.linear.weight, std=0.02)
        # NOTE(review): this is normal_(mean=0) with the default std=1.0; zero
        # initialisation may have been intended — confirm.
        nn.init.normal_(self.linear.bias, 0)

    def forward(self, hidden_states, mc_token_ids=None):
        """ Extract the classification token hidden state of each choice and project it to a score.

            hidden_states: shape (bsz, num_choices, seq_length, hidden_size)
            mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices).
                When None, the last token of each sequence is used.
            Returns logits of shape (bsz, num_choices).
        """
        if mc_token_ids is not None:
            gather_index = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
        else:
            last_position = hidden_states.shape[2] - 1
            gather_index = torch.full_like(hidden_states[:, :, :1, :], last_position, dtype=torch.long)
        # gather_index: (bsz, num_choices, 1, hidden_size)
        picked = hidden_states.gather(2, gather_index).squeeze(2)  # (bsz, num_choices, hidden_size)
        picked = self.dropout(picked.transpose(1, 2)).transpose(1, 2)
        return self.linear(picked).squeeze(-1)  # (bsz, num_choices)


372
class GPT2PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = GPT2Config
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights of a single sub-module (used with `nn.Module.apply`).

        Linear / Embedding / Conv1D weights ~ N(0, config.initializer_range), biases zeroed;
        LayerNorm weight set to 1 and bias to 0. Other modules are left untouched.
        """
        projection_types = (nn.Linear, nn.Embedding, Conv1D)
        if isinstance(module, projection_types):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            has_bias = isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None
            if has_bias:
                module.bias.data.zero_()
        elif isinstance(module, LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `gpt2`
                    . `gpt2-medium`
                - a path or url to a pretrained model archive containing:
                    . a JSON configuration file for the model
                    . a PyTorch dump of a GPT2Model instance
                - a path or url to a pretrained model archive containing:
                    . a JSON configuration file for the model
                    . a TensorFlow checkpoint with trained weights
            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
            num_special_tokens: optional number of special-token embeddings to (re)size the model to
                after loading; popped here and not forwarded to the base loader.
            *inputs, **kwargs: additional input for the specific GPT2 class
        """
        special_token_count = kwargs.pop('num_special_tokens', None)

        # NOTE(review): `cls` is passed explicitly, which only works if
        # PreTrainedModel.from_pretrained is a plain function and not a classmethod — confirm in model_utils.
        model = PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)

        # Add additional embeddings for special tokens if needed
        # This step also make sure we are still sharing the output and input embeddings after loading weights
        model.set_num_special_tokens(special_token_count)
        return model


class GPT2Model(GPT2PreTrainedModel):
    """OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").

    GPT-2 uses a single embedding matrix to store the word and special embeddings.
    Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
    Special tokens need to be trained during the fine-tuning if you use them.
    The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.

    The embeddings are ordered as follow in the token embeddings matrix:
        [0,                                                         ----------------------
         ...                                                        -> word embeddings
         config.vocab_size - 1,                                     ______________________
         config.vocab_size,
         ...                                                        -> special embeddings
         config.vocab_size + config.n_special - 1]                  ______________________

    where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
        total_tokens_embeddings = config.vocab_size + config.n_special
    You should use the associated indices to index the embeddings.

    Params:
        `config`: a GPT2Config class instance with the configuration to build a new model.
            `config.output_attentions` and `config.output_hidden_states` select the optional outputs.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
            where d_1 ... d_n are arbitrary dimensions) with the BPE token indices selected in the range
            [0, config.total_tokens_embeddings[ (word tokens first, then special tokens).
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1[).
            Defaults to past_length, ..., past_length + sequence_length - 1.
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids.
            The indices are embedded with the same token embedding matrix as `input_ids` and
            the result is added to each input token (on top of the word and position embeddings)
            before the first self-attention block.
        `past`: an optional list (one entry per layer) of the tensors returned as `presents` by a
            previous call, holding pre-computed keys and values of the attention blocks to speed up
            sequential decoding.
        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values between 0 and 1.
            It multiplies each head's attention probabilities: 1.0 => head is kept, 0.0 => head is fully masked.

    Outputs a tuple consisting of:
        `last_hidden_state`: the output of the last layer after the final layer norm, of size
            [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size]
            where d_1 ... d_n are the dimensions of input_ids)
        `presents`: a tuple with one tensor per layer stacking that layer's attention keys and values.
            They can be fed back as `past` to speed up sequential decoding.
        `all_hidden_states` (only if config.output_hidden_states): a tuple of number of layers + 1 tensors
            shaped like `last_hidden_state`: the input of every layer (starting with the embeddings output)
            followed by the final output.
        `all_attentions` (only if config.output_attentions): a tuple with the attention probabilities of each
            layer, of shape [d_1, ..., d_n, num_heads, sequence_length, past_length + sequence_length].

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

    config = modeling_gpt2.GPT2Config()

    model = modeling_gpt2.GPT2Model(config)
    hidden_states, presents = model(input_ids)
    ```
    """

    def __init__(self, config):
        super(GPT2Model, self).__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens=None):
        """ Update input embeddings with a new embedding matrix if needed.

        Does nothing when `num_special_tokens` is None or already equal to config.n_special.
        Otherwise rebuilds `wte` with vocab_size + num_special_tokens rows, initializes it, and
        copies back the first vocab_size (word) rows from the previous matrix.
        """
        if num_special_tokens is None or self.config.n_special == num_special_tokens:
            return
        # Update config
        self.config.n_special = num_special_tokens
        # Build new embeddings and initialize all new embeddings (in particular the special tokens)
        old_embed = self.wte
        self.wte = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
        # Match the old matrix's device AND dtype, so an fp16 model does not end up with an fp32 wte.
        self.wte.to(device=old_embed.weight.device, dtype=old_embed.weight.dtype)
        self.init_weights(self.wte)
        # Copy word embeddings from the previous weights
        self.wte.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # past[layer] stacks (key, value); the key's dim -2 is the cached sequence length.
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask is converted to shape n_layer x 1 x n_heads x 1 x 1 so each layer's slice broadcasts
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer

        # Flatten any leading dims to a single batch dim; restored via output_shape below.
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            # Token types share the token embedding matrix.
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states += (hidden_states.view(*output_shape),)

            outputs = block(hidden_states, layer_past, head_mask[i])
            hidden_states, present = outputs[:2]
            presents += (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, presents)
        if self.output_hidden_states:
            outputs += (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs += (all_attentions,)
        return outputs  # last hidden state, presents, (all hidden_states), (attentions)
thomwolf's avatar
thomwolf committed
591
592
593
594
595
596


class GPT2LMHeadModel(GPT2PreTrainedModel):
    """OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners").

    The LM head is built from the transformer's input embedding matrix (`transformer.wte.weight`),
    so input and output embeddings are shared.

    Params:
        `config`: a GPT2Config class instance with the configuration to build a new model.
            Whether hidden states and attention weights are also returned is governed by the
            `output_hidden_states` / `output_attentions` settings of the underlying GPT2Model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
            where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1[.
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `lm_labels`: optional language modeling labels: torch.LongTensor with the same shape as input_ids
            with indices selected in [-1, 0, ..., config.vocab_size - 1]. All labels set to -1 are ignored (masked),
            the loss is only computed for the labels set in [0, ..., config.vocab_size - 1].
            Labels are shifted inside the model: the logits at position i are scored against the label at
            position i + 1, so `lm_labels` may simply be a copy of `input_ids`.
        `past`: an optional list of torch.FloatTensor that contains pre-computed hidden-states
            (key and values in the attention blocks) to speed up sequential decoding
            (this is the `presents` output of a previous call, cf. below).
        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values between 0 and 1.
            It is used to nullify selected heads of the transformer.
            NOTE(review): whether 1.0 means "head kept" or "head masked" depends on how the attention
            module applies this mask — confirm there before relying on a convention.

    Outputs:
        A tuple whose elements are, in order:
            `loss` (only if `lm_labels` is not `None`): the language modeling loss, a scalar torch.FloatTensor
                (mean cross-entropy over the non-ignored, shifted positions).
            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, config.vocab_size]
                (or more generally [d_1, ..., d_n, sequence_length, config.vocab_size] where d_1 ... d_n are the
                leading dimensions of input_ids).
            `presents`: a tuple of pre-computed hidden-states (key and values in each attention block) as
                torch.FloatTensors. They can be fed back as `past` to speed up sequential decoding.
            `all_hidden_states` (only if the base model outputs hidden states): the hidden states fed to each
                block followed by the final (layer-normed) output.
            `attentions` (only if the base model outputs attentions): one attention-weight tensor per layer.

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

    config = modeling_gpt2.GPT2Config()

    model = modeling_gpt2.GPT2LMHeadModel(config)
    lm_logits, presents = model(input_ids)
    ```
    """

    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__(config)
        self.transformer = GPT2Model(config)
        # Tie the output projection to the input embedding matrix.
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
        """ Update input and output embeddings with the new embedding matrix.
            The LM head is re-pointed at the transformer's (resized) `wte.weight`, so the embeddings stay shared.

            Args:
                num_special_tokens: number of special tokens, forwarded to `GPT2Model.set_num_special_tokens`.
                predict_special_tokens: stored on both configs and forwarded to the LM head.
        """
        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
        self.transformer.set_num_special_tokens(num_special_tokens)
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
        transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        # Keep presents / hidden states / attentions from the base model after the logits.
        outputs = (lm_logits,) + transformer_outputs[1:]
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            # Flatten the tokens; labels equal to -1 are ignored.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
thomwolf's avatar
thomwolf committed
674
675
676
677
678
679


class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    """OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are Unsupervised Multitask Learners").

    The LM head is built from the transformer's input embedding matrix (`transformer.wte.weight`),
    so input and output embeddings are shared.

    Params:
        `config`: a GPT2Config class instance with the configuration to build a new model.
            Whether hidden states and attention weights are also returned is governed by the
            `output_hidden_states` / `output_attentions` settings of the underlying GPT2Model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
            indices selected in the range [0, config.vocab_size[
        `mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from
            which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence).
            NOTE(review): defaults to None in the signature but is always forwarded to the multiple choice
            head — confirm the head accepts None before omitting it.
        `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with indices selected in [-1, 0, ..., config.vocab_size - 1]. All labels set to -1 are ignored (masked),
            the loss is only computed for the labels set in [0, ..., config.vocab_size - 1].
            Labels are shifted inside the model: the logits at position i are scored against the label at position i + 1.
        `mc_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_choices - 1].
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1[.
        `past`: an optional list of torch.FloatTensor that contains pre-computed hidden-states
            (key and values in the attention blocks) to speed up sequential decoding
            (this is the `presents` output of a previous call, cf. below).
        `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with values between 0 and 1.
            It is used to nullify selected heads of the transformer.
            NOTE(review): whether 1.0 means "head kept" or "head masked" depends on how the attention
            module applies this mask — confirm there before relying on a convention.

    Outputs:
        A tuple whose elements are, in order:
            `lm_loss` (only if `lm_labels` is not `None`): the language modeling loss (scalar torch.FloatTensor).
            `mc_loss` (only if `mc_labels` is not `None`): the multiple choice loss (scalar torch.FloatTensor).
            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, config.vocab_size]
            `mc_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
            `presents`: a tuple of pre-computed hidden-states (key and values in each attention block) as
                torch.FloatTensors. They can be fed back as `past` to speed up sequential decoding.
            `all_hidden_states` (only if the base model outputs hidden states).
            `attentions` (only if the base model outputs attentions).

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])  # (bsz, number of choice, seq length)
    mc_token_ids = torch.LongTensor([[2, 1]])  # (bsz, number of choice)

    config = modeling_gpt2.GPT2Config()

    model = modeling_gpt2.GPT2DoubleHeadsModel(config)
    lm_logits, multiple_choice_logits, presents = model(input_ids, mc_token_ids)
    ```
    """

    def __init__(self, config):
        super(GPT2DoubleHeadsModel, self).__init__(config)
        self.transformer = GPT2Model(config)
        # Tie the output projection to the input embedding matrix.
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.multiple_choice_head = GPT2MultipleChoiceHead(config)

        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
        """ Update input and output embeddings with the new embedding matrix.
            The LM head is re-pointed at the transformer's (resized) `wte.weight`, so the embeddings stay shared.

            Args:
                num_special_tokens: number of special tokens, forwarded to `GPT2Model.set_num_special_tokens`.
                predict_special_tokens: stored on both configs and forwarded to the LM head.
        """
        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
        self.transformer.set_num_special_tokens(num_special_tokens)
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

    def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
                position_ids=None, past=None, head_mask=None):
        transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)

        # Keep presents / hidden states / attentions from the base model after the logits.
        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        # The mc loss is prepended first and the lm loss second, so the final order is (lm_loss, mc_loss, ...).
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
                            mc_labels.view(-1))
            outputs = (loss,) + outputs
        if lm_labels is not None:
            # Shift so that tokens < n predict n; labels equal to -1 are ignored.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)