# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import copy
import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
from .modeling import BertLayerNorm as LayerNorm

logger = logging.getLogger(__name__)

# Shortcut name -> download URL of the pretrained PyTorch weights.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
}
# Shortcut name -> download URL of the matching JSON configuration file.
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
    "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
}

def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
    """ Load a TensorFlow GPT-2 checkpoint into a PyTorch model.

    Every variable listed in the checkpoint is read, squeezed, and copied into
    the attribute of ``model`` reached by walking the variable's slash-separated
    name (after dropping its first 6 characters, the ``model/`` scope).
    A path component such as ``h3`` (letters followed by digits) selects
    attribute ``h`` and then index ``3``. ``w``/``g`` map to ``weight``, ``b``
    maps to ``bias``, and ``wpe``/``wte`` map to that module's ``weight``.

    Args:
        model: the PyTorch module to fill; its parameters are modified in place.
        gpt2_checkpoint_path: path to the TensorFlow checkpoint.

    Returns:
        ``model`` itself.

    Raises:
        ImportError: if TensorFlow or NumPy cannot be imported.
        AssertionError: if a checkpoint array's shape differs from the target
            parameter's shape. The exception args are
            ``(parameter_shape, array_shape)``.
    """
    try:
        import re
        import numpy as np  # noqa: F401 -- imported so a missing NumPy is reported here too
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
                     "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())

    # Path components like "h0" / "h11": an alphabetic prefix followed by an index.
    # Anchored with \Z instead of re.fullmatch so this also works on Python 2.
    indexed_scope = re.compile(r'([A-Za-z]+)(\d+)\Z')
    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        pointer = model
        for scope in name.split('/'):
            match = indexed_scope.match(scope)
            if match:
                attr, index = match.group(1), int(match.group(2))
            else:
                attr, index = scope, None
            if attr in ('w', 'g'):
                pointer = getattr(pointer, 'weight')
            elif attr == 'b':
                pointer = getattr(pointer, 'bias')
            elif attr in ('wpe', 'wte'):
                pointer = getattr(pointer, attr)
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, attr)
            if index is not None:
                pointer = pointer[index]
        # Explicit raise (not `assert`) so the check survives `python -O`;
        # the args match what the previous assert-based check produced.
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model


def gelu(x):
    """Gaussian Error Linear Unit, tanh approximation, applied elementwise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """
    cubic_term = 0.044715 * torch.pow(x, 3)
    inner = math.sqrt(2 / math.pi) * (x + cubic_term)
    return 0.5 * x * (1 + torch.tanh(inner))


class GPT2Config(object):
    """Configuration class to store the configuration of a `GPT2Model`.
    """

    def __init__(
        self,
        vocab_size_or_config_json_file=50257,
        n_special=0,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        predict_special_tokens=True
    ):
        """Constructs GPT2Config.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
                When a path is given, every key of the JSON object becomes an attribute
                and the remaining keyword arguments are ignored.
            n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
            n_positions: Number of positional embeddings.
            n_ctx: Size of the causal mask (usually same as n_positions).
            n_embd: Dimensionality of the embeddings and hidden states.
            n_layer: Number of hidden layers in the Transformer encoder.
            n_head: Number of attention heads for each attention layer in
                the Transformer encoder.
            layer_norm_epsilon: epsilon to use in the layer norm layers
            resid_pdrop: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attn_pdrop: The dropout ratio for the attention
                probabilities.
            embd_pdrop: The dropout ratio for the embeddings.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            predict_special_tokens: should we predict special tokens (when the model has a LM head)

        Raises:
            ValueError: if the first argument is neither an int nor a string path.
        """
        # Python 2 paths may arrive as `unicode`; the version check short-circuits on Python 3.
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.n_special = n_special
            self.n_ctx = n_ctx
            self.n_positions = n_positions
            self.n_embd = n_embd
            self.n_layer = n_layer
            self.n_head = n_head
            self.resid_pdrop = resid_pdrop
            self.embd_pdrop = embd_pdrop
            self.attn_pdrop = attn_pdrop
            self.layer_norm_epsilon = layer_norm_epsilon
            self.initializer_range = initializer_range
            self.predict_special_tokens = predict_special_tokens
        else:
            raise ValueError(
                "First argument must be either a vocabulary size (int) "
                "or the path to a pretrained model config file (str)"
            )

    @property
    def total_tokens_embeddings(self):
        """Number of rows in the token embedding matrix: vocab_size + n_special."""
        return self.vocab_size + self.n_special

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config of this class from a Python dictionary of parameters.

        Defaults are set first; every key of `json_object` then overwrites or
        adds the matching attribute.
        """
        # Use `cls` (not GPT2Config) so subclasses get instances of their own type.
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config of this class from a json file of parameters."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (sorted keys, 2-space indent, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
            writer.write(self.to_json_string())


class Conv1D(nn.Module):
    """Affine projection over the last dimension, with weight stored as (nx, nf).

    Args:
        nf: number of output features.
        nx: number of input features.
    The weight is initialised from N(0, 0.02) and the bias to zeros.
    """

    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        init_weight = torch.empty(nx, nf)
        nn.init.normal_(init_weight, std=0.02)
        self.weight = Parameter(init_weight)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        # Flatten all leading dims, apply bias + x @ weight, then restore them.
        out_shape = x.size()[:-1] + (self.nf,)
        flat_input = x.view(-1, x.size(-1))
        projected = torch.addmm(self.bias, flat_input, self.weight)
        return projected.view(*out_shape)


class Attention(nn.Module):
    """Causal multi-head self-attention with a fused query/key/value projection.

    Args:
        nx: width of the incoming hidden states (also the attention width).
        n_ctx: side length of the lower-triangular causal mask buffer.
        config: object providing ``n_head``, ``attn_pdrop`` and ``resid_pdrop``.
        scale: if True, attention scores are divided by sqrt(head size).
        output_attentions: if True, ``forward`` additionally returns the
            attention probabilities as the first element of its result.
    """

    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        causal_mask = torch.tril(torch.ones(n_ctx, n_ctx))
        self.register_buffer("bias", causal_mask.view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.output_attentions = output_attentions
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def _attn(self, q, k, v):
        scores = torch.matmul(q, k)
        if self.scale:
            scores = scores / math.sqrt(v.size(-1))
        # Take the last `n_queries` rows of the causal mask, so queries line up with
        # the newest keys when cached past keys are prepended.
        n_queries, n_keys = scores.size(-2), scores.size(-1)
        mask = self.bias[:, :, n_keys - n_queries:n_keys, :n_keys]
        # Masked positions get a large negative score (-1e4) before the softmax.
        scores = scores * mask - 1e4 * (1 - mask)
        probs = self.attn_dropout(nn.functional.softmax(scores, dim=-1))
        context = torch.matmul(probs, v)
        if self.output_attentions:
            return probs, context
        return context

    def merge_heads(self, x):
        # Inverse of split_heads(k=False); Tensorflow implem: fct merge_states
        x = x.permute(0, 2, 1, 3).contiguous()
        merged_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*merged_shape)

    def split_heads(self, x, k=False):
        # Tensorflow implem: fct split_states
        head_features = x.size(-1) // self.n_head
        x = x.view(*(x.size()[:-1] + (self.n_head, head_features)))
        # keys:   (batch, head, head_features, seq_length)
        # others: (batch, head, seq_length, head_features)
        order = (0, 2, 3, 1) if k else (0, 2, 1, 3)
        return x.permute(*order)

    def forward(self, x, layer_past=None):
        query, key, value = self.c_attn(x).split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            # Cached keys are stored transposed (see `present`); transpose them back.
            past_key = layer_past[0].transpose(-2, -1)
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        # Transpose keys so they have the same shape as values and can be stacked.
        present = torch.stack((key.transpose(-2, -1), value))
        attn_out = self._attn(query, key, value)
        if self.output_attentions:
            attentions, attn_out = attn_out
        attn_out = self.resid_dropout(self.c_proj(self.merge_heads(attn_out)))
        if self.output_attentions:
            return attentions, attn_out, present
        return attn_out, present


class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: Conv1D -> gelu -> Conv1D -> dropout.

    Args:
        n_state: inner (expanded) width of the hidden layer.
        config: object providing ``n_embd`` and ``resid_pdrop``.
    """

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        model_width = config.n_embd
        self.c_fc = Conv1D(n_state, model_width)
        self.c_proj = Conv1D(model_width, n_state)
        self.act = gelu
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        expanded = self.act(self.c_fc(x))
        return self.dropout(self.c_proj(expanded))


class Block(nn.Module):
    """One transformer layer with LayerNorm applied before each sub-layer.

    ``x + attn(ln_1(x))`` is followed by ``x + mlp(ln_2(x))``. ``forward``
    returns ``(x, present)``, or ``(attentions, x, present)`` when
    ``output_attentions`` is set.
    """

    def __init__(self, n_ctx, config, scale=False, output_attentions=False):
        super(Block, self).__init__()
        width = config.n_embd
        self.output_attentions = output_attentions
        self.ln_1 = LayerNorm(width, eps=config.layer_norm_epsilon)
        self.attn = Attention(width, n_ctx, config, scale, output_attentions)
        self.ln_2 = LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * width, config)

    def forward(self, x, layer_past=None):
        attn_result = self.attn(self.ln_1(x), layer_past=layer_past)
        if self.output_attentions:
            attentions, attn_out, present = attn_result
        else:
            attn_out, present = attn_result
        hidden = x + attn_out
        hidden = hidden + self.mlp(self.ln_2(hidden))
        if self.output_attentions:
            return attentions, hidden, present
        return hidden, present


class GPT2LMHead(nn.Module):
    """ Language Model Head for the transformer

    Projects hidden states onto the token embedding matrix, whose weight is
    shared with the embeddings passed in, to produce token logits. When
    ``predict_special_tokens`` is False, logits are truncated to the first
    ``vocab_size`` entries.
    """

    def __init__(self, model_embeddings_weights, config):
        super(GPT2LMHead, self).__init__()
        self.n_embd = config.n_embd
        self.vocab_size = config.vocab_size
        self.predict_special_tokens = config.predict_special_tokens
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        # Pass the configured flag explicitly: set_embeddings_weights() defaults it
        # to True, which used to silently override config.predict_special_tokens.
        self.set_embeddings_weights(model_embeddings_weights,
                                    predict_special_tokens=config.predict_special_tokens)

    def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
        """Tie the decoder weight to `model_embeddings_weights` and set whether
        logits for special tokens (indices >= vocab_size) are returned."""
        self.predict_special_tokens = predict_special_tokens
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, hidden_state):
        lm_logits = self.decoder(hidden_state)
        if not self.predict_special_tokens:
            # Drop the trailing special-token columns, keep the word vocabulary only.
            lm_logits = lm_logits[..., :self.vocab_size]
        return lm_logits


class GPT2MultipleChoiceHead(nn.Module):
    """ Classifier Head for the transformer

    For each choice, picks the hidden state at position ``mc_token_ids`` and maps
    it to a single logit with a linear layer.
    """

    def __init__(self, config):
        super(GPT2MultipleChoiceHead, self).__init__()
        self.n_embd = config.n_embd
        self.dropout = nn.Dropout2d(config.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
        self.linear = nn.Linear(config.n_embd, 1)

        nn.init.normal_(self.linear.weight, std=0.02)
        # Zero the bias. The previous `nn.init.normal_(bias, 0)` passed 0 as the
        # *mean*, so the bias was drawn from N(0, 1) instead.
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, hidden_states, mc_token_ids):
        # Classification logits
        # hidden_state (bsz, num_choices, seq_length, hidden_size)
        # mc_token_ids (bsz, num_choices)
        mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
        # (bsz, num_choices, 1, hidden_size)
        multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
        # (bsz, num_choices, hidden_size)
        multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
        multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
        # (bsz, num_choices)
        return multiple_choice_logits


class GPT2PreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super(GPT2PreTrainedModel, self).__init__()
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    def init_weights(self, module):
        """ Initialize the weights of a single submodule in place.

        - ``nn.Linear`` / ``nn.Embedding`` weights: N(0, ``config.initializer_range``).
        - ``LayerNorm``: weight set to 1, bias set to 0.
        - ``nn.Linear`` biases (when present): set to 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `gpt2`
                    . `gpt2-medium`
                - a path or url to a directory containing:
                    . a configuration file for the model (named `CONFIG_NAME`)
                    . a PyTorch dump of a GPT2Model instance (named `WEIGHTS_NAME`)
                - a path or url to a directory containing (use with `from_tf=True`):
                    . a configuration file for the model (named `CONFIG_NAME`)
                    . a TensorFlow checkpoint with trained weights (named `WEIGHTS_NAME`)
            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
            num_special_tokens: number of special token embeddings to set up after loading
                (defaults to `config.n_special`; not applied when `from_tf=True`)
            *inputs, **kwargs: additional input for the specific GPT2 class

        Returns:
            The loaded model, or None if the weights/config files could not be found
            (the problem is logged as an error).

        Raises:
            RuntimeError: if loading the state dict reports errors.
        """
        state_dict = kwargs.pop('state_dict', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_tf = kwargs.pop('from_tf', False)
        num_special_tokens = kwargs.pop('num_special_tokens', None)

        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} and {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
                    archive_file, config_file
                )
            )
            return None
        if resolved_archive_file == archive_file and resolved_config_file == config_file:
            logger.info("loading weights file {}".format(archive_file))
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
        # Load config
        config = GPT2Config.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        if from_tf:
            # Directly load from a TensorFlow checkpoint (stored as NumPy array)
            return load_tf_weights_in_gpt2(model, resolved_archive_file)

        # Rename short parameter suffixes: ".g"/".w" -> ".weight", ".b" -> ".bias".
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if key.endswith(".g"):
                new_key = key[:-2] + ".weight"
            elif key.endswith(".b"):
                new_key = key[:-2] + ".bias"
            elif key.endswith(".w"):
                new_key = key[:-2] + ".weight"
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=""):
            # Recursively load `module` and its children from the matching state_dict keys.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        # A bare-transformer checkpoint (no "transformer." prefix) is loaded into
        # the model's `transformer` submodule when the model has one.
        start_model = model
        if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
            start_model = model.transformer
        load(start_model, prefix="")

        if missing_keys:
            logger.info(
                "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
            )
        if unexpected_keys:
            logger.info(
                "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
            )
        if error_msgs:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
            )

        # Add additional embeddings for special tokens if needed
        # This step also make sure we are still sharing the output and input embeddings after loading weights
        model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
        return model


class GPT2Model(GPT2PreTrainedModel):
    """OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").

    GPT-2 uses a single embedding matrix to store the word and special embeddings.
    Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
    Special tokens need to be trained during the fine-tuning if you use them.
    The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.

    The embeddings are ordered as follows in the token embeddings matrix:
        [0,                                                         ----------------------
         ...                                                        -> word embeddings
         config.vocab_size - 1,                                     ______________________
         config.vocab_size,
         ...                                                        -> special embeddings
         config.vocab_size + config.n_special - 1]                  ______________________

    where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
        total_tokens_embeddings = config.vocab_size + config.n_special
    You should use the associated indices to index the embeddings.

    Params:
        config: a GPT2Config class instance with the configuration to build a new model
        output_attentions: if True, every `Block` is built with `output_attentions=True` and the list
            of their per-layer `attentions` outputs is prepended to the output tuple (see Outputs).

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
            where d_1 ... d_n are arbitrary dimensions) with the BPE token indices selected in the range
            [0, config.total_tokens_embeddings[ (word tokens first, then special tokens, see above).
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1]).
            Defaults to past_length, ..., past_length + sequence_length - 1, where past_length is the
            number of positions already cached in `past` (0 when `past` is None).
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            Token types are looked up in the same `wte` matrix as `input_ids`, so their indices
            must lie in the same range.
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `past`: an optional list of torch.FloatTensor that contains pre-computed hidden-states
            (key and values in the attention blocks) to speed up sequential decoding
            (this is the presents output of the model, cf. below). It must contain exactly one entry
            per layer (config.n_layer entries); a ValueError is raised otherwise.

    Outputs a tuple consisting of:
        `all_attentions` (only when the model was built with output_attentions=True): a list holding the
            `attentions` output of each `Block`, first layer first.
        `hidden_states`: the encoded-hidden-states at the top of the model
            as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
            (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimension of input_ids)
        `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as
            torch.FloatTensors. They can be reused to speed up sequential decoding.

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

    config = modeling_gpt2.GPT2Config()

    model = modeling_gpt2.GPT2Model(config)
    hidden_states, presents = model(input_ids)
    ```
    """

    def __init__(self, config, output_attentions=False):
        super(GPT2Model, self).__init__(config)
        self.output_attentions = output_attentions
        # Token embeddings: word embeddings followed by the special-token embeddings.
        self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
        # Each layer gets its own deep copy so parameters are not shared between layers.
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens):
        """Rebuild the input embedding matrix if the number of special tokens changed.

        The first `config.vocab_size` rows (the word embeddings) are copied from the previous
        matrix. All special-token rows are freshly initialized with `init_weights`.
        The new matrix is placed on the same device and uses the same dtype as the old one.
        """
        if self.config.n_special == num_special_tokens:
            return
        # Update config
        self.config.n_special = num_special_tokens
        # Build new embeddings and initialize all new embeddings (in particular the special tokens)
        old_embed = self.wte
        self.wte = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
        self.init_weights(self.wte)
        # Match the old matrix's device *and* dtype: nn.Embedding is created in the default
        # dtype, which would otherwise desynchronize e.g. a half-precision model.
        self.wte.to(device=old_embed.weight.device, dtype=old_embed.weight.dtype)
        # Copy word embeddings from the previous weights
        self.wte.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        """Run the transformer stack; see the class docstring for inputs and outputs."""
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # zip() below would silently skip the layers that have no cache entry.
            if len(past) != len(self.h):
                raise ValueError(
                    "`past` must contain one entry per layer: expected {}, got {}".format(len(self.h), len(past))
                )
            # NOTE(review): assumes past[layer][0] is the cached key tensor with the sequence on
            # dim -2 (as produced by Block/Attention) — confirm against Attention.
            past_length = past[0][0].size(-2)
        if position_ids is None:
            # Continue numbering positions after the cached ones.
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Flatten any leading dimensions into a single batch dimension; restored at the end.
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            # Token types share the token embedding matrix.
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        presents = []
        all_attentions = []
        for block, layer_past in zip(self.h, past):
            if self.output_attentions:
                attentions, hidden_states, present = block(hidden_states, layer_past)
                all_attentions.append(attentions)
            else:
                hidden_states, present = block(hidden_states, layer_past)
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        if self.output_attentions:
            return all_attentions, hidden_states.view(*output_shape), presents
        return hidden_states.view(*output_shape), presents


class GPT2LMHeadModel(GPT2PreTrainedModel):
    """OpenAI GPT-2 model with a Language Modeling head ("Language Models are Unsupervised Multitask Learners").

    The language modeling head is built on `self.transformer.wte.weight`, so input and output
    embeddings are shared; `set_num_special_tokens` keeps them shared after resizing.

    Params:
        config: a GPT2Config class instance with the configuration to build a new model
        output_attentions: passed to the underlying GPT2Model; if True, the list of per-layer
            attentions is prepended to the (non-loss) output tuple.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
            where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1]).
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `lm_labels`: optional language modeling labels: torch.LongTensor with the same shape as input_ids.
            Labels set to -1 are ignored (masked); the loss is only computed for labels that are valid
            indices into the last dimension of `lm_logits`.
            The shift is done inside the model: the logits at position i are scored against the label at
            position i + 1, so `lm_labels` can simply be set to `input_ids`.
        `past`: an optional list of torch.FloatTensor that contains pre-computed hidden-states
            (key and values in the attention blocks) to speed up sequential decoding
            (this is the presents output of the model, cf. below), one entry per layer.

    Outputs:
        if `lm_labels` is not `None`:
            Outputs the language modeling loss (a scalar tensor).
        else a tuple:
            `all_attentions` (only when the model was built with output_attentions=True): the list of
                per-layer attentions returned by the underlying GPT2Model.
            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, config.vocab_size]
                (or more generally [d_1, ..., d_n, config.vocab_size] where d_1 ... d_n are the dimension of input_ids)
                NOTE(review): the last dimension presumably also covers the special tokens when
                `predict_special_tokens` is set — confirm against GPT2LMHead.
            `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as
                torch.FloatTensors. They can be reused to speed up sequential decoding.

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

    config = modeling_gpt2.GPT2Config()

    model = modeling_gpt2.GPT2LMHeadModel(config)
    lm_logits, presents = model(input_ids)
    ```
    """

    def __init__(self, config, output_attentions=False):
        super(GPT2LMHeadModel, self).__init__(config)
        self.transformer = GPT2Model(config, output_attentions=output_attentions)
        # Tie the output projection to the input token embeddings.
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
        """Update input and output embeddings with a new embedding matrix.

        Resizes the transformer's token embeddings, then re-points the LM head at the new
        matrix so input and output embeddings stay shared.
        """
        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
        self.transformer.set_num_special_tokens(num_special_tokens)
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
        """Compute LM logits (or the LM loss when `lm_labels` is given); see the class docstring."""
        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
        if self.transformer.output_attentions:
            all_attentions, hidden_states, presents = transformer_output
        else:
            hidden_states, presents = transformer_output
        lm_logits = self.lm_head(hidden_states)
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            return loss
        if self.transformer.output_attentions:
            return all_attentions, lm_logits, presents
        return lm_logits, presents


class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
    """OpenAI GPT-2 model with a Language Modeling and a Multiple Choice head ("Language Models are Unsupervised Multitask Learners").

    The language modeling head is built on `self.transformer.wte.weight`, so input and output
    embeddings are shared; `set_num_special_tokens` keeps them shared after resizing.

    Params:
        config: a GPT2Config class instance with the configuration to build a new model
        output_attentions: passed to the underlying GPT2Model; if True, the list of per-layer
            attentions is prepended to the (non-loss) output tuple.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the BPE token
            indices selected in the range [0, config.vocab_size[
        `mc_token_ids`: a torch.LongTensor of shape [batch_size, num_choices] with the index of the token from
            which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
        `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length].
            Labels set to -1 are ignored (masked); the loss is only computed for labels that are valid
            indices into the last dimension of `lm_logits`. As in GPT2LMHeadModel, the logits at position i
            are scored against the label at position i + 1 (the shift is done inside the model).
        `mc_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_choices - 1].
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
            You can use it to add a third type of embedding to each input token in the sequence
            (the previous two being the word and position embeddings).
            The input, position and token_type embeddings are summed inside the Transformer before the first
            self-attention block.
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [0, config.n_positions - 1]).
        `past`: an optional list of torch.FloatTensor that contains pre-computed hidden-states
            (key and values in the attention blocks) to speed up sequential decoding
            (this is the presents output of the model, cf. below), one entry per layer.

    Outputs:
        if `lm_labels` or `mc_labels` is not `None`:
            Outputs a list of losses: the language modeling loss (present only when `lm_labels` is given)
            followed by the multiple choice loss (present only when `mc_labels` is given).
        else: a tuple with
            `all_attentions` (only when the model was built with output_attentions=True): the list of
                per-layer attentions returned by the underlying GPT2Model.
            `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, config.vocab_size]
            `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
            `presents`: a list of pre-computed hidden-states (key and values in each attention blocks) as
                torch.FloatTensors. They can be reused to speed up sequential decoding.

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])  # (bsz, number of choice, seq length)
    mc_token_ids = torch.LongTensor([[2, 1]])  # (bsz, number of choice)

    config = modeling_gpt2.GPT2Config()

    model = modeling_gpt2.GPT2DoubleHeadsModel(config)
    lm_logits, multiple_choice_logits, presents = model(input_ids, mc_token_ids)
    ```
    """

    def __init__(self, config, output_attentions=False):
        super(GPT2DoubleHeadsModel, self).__init__(config)
        self.transformer = GPT2Model(config, output_attentions=output_attentions)
        # Tie the output projection to the input token embeddings.
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.multiple_choice_head = GPT2MultipleChoiceHead(config)
        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
        """Update input and output embeddings with a new embedding matrix.

        Resizes the transformer's token embeddings, then re-points the LM head at the new
        matrix so input and output embeddings stay shared.
        """
        self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
        self.transformer.set_num_special_tokens(num_special_tokens)
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
        """Compute LM and multiple-choice logits (or their losses); see the class docstring."""
        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
        if self.transformer.output_attentions:
            all_attentions, hidden_states, presents = transformer_output
        else:
            hidden_states, presents = transformer_output
        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
        losses = []
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
        # Any provided label set switches the output to the list of losses.
        if losses:
            return losses
        if self.transformer.output_attentions:
            return all_attentions, lm_logits, mc_logits, presents
        return lm_logits, mc_logits, presents