"lightx2v/models/vscode:/vscode.git/clone" did not exist on "84ece5f581e864152b3d1ef10fec06358b706b3f"
modeling_openai.py 34.1 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
thomwolf's avatar
thomwolf committed
2
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
thomwolf's avatar
thomwolf committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT model."""

18
19
from __future__ import absolute_import, division, print_function, unicode_literals

20
import collections
thomwolf's avatar
thomwolf committed
21
import json
thomwolf's avatar
thomwolf committed
22
import logging
23
24
import math
import os
thomwolf's avatar
thomwolf committed
25
26
import sys
from io import open
thomwolf's avatar
thomwolf committed
27
28
29

import torch
import torch.nn as nn
thomwolf's avatar
thomwolf committed
30
from torch.nn import CrossEntropyLoss
thomwolf's avatar
thomwolf committed
31
32
from torch.nn.parameter import Parameter

33
from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
34
                             PreTrainedModel, prune_conv1d_layer, SequenceSummary)
thomwolf's avatar
thomwolf committed
35
from .modeling_bert import BertLayerNorm as LayerNorm
thomwolf's avatar
thomwolf committed
36

thomwolf's avatar
thomwolf committed
37
38
logger = logging.getLogger(__name__)

# Download locations of the pretrained "openai-gpt" checkpoint weights and its configuration.
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin",
}
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json",
}
41

42

43
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
    """ Load TF pre-trained weights in a pytorch model (from the NumPy arrays released by OpenAI).

    Args:
        model: module exposing ``tokens_embed``, ``positions_embed`` and the sub-modules
            addressed by the names listed in ``parameters_names.json``.
        config: not used by this loader.
        openai_checkpoint_folder_path: folder holding ``parameters_names.json``,
            ``params_shapes.json`` and ``params_0.npy`` ... ``params_9.npy``. A path that
            contains ``.ckpt`` is also accepted; its parent directory is used instead.

    Returns:
        ``model``, whose parameters have been overwritten in place.

    Raises:
        AssertionError: if a checkpoint array does not have the shape of the parameter it is
            loaded into (the two shapes are appended to the exception args), or a parameter
            name does not end with ``":0"``.
    """
    import re
    import numpy as np

    if '.ckpt' in openai_checkpoint_folder_path:
        openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)

    logger.info("Loading weights from {}".format(openai_checkpoint_folder_path))

    # Use context managers so the metadata files are closed right away.
    with open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8') as f:
        names = json.load(f)
    with open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8') as f:
        shapes = json.load(f)

    # The parameters are stored as one flat buffer split over 10 .npy shards: concatenate the
    # shards, cut the buffer at the cumulative parameter sizes, then restore each shape.
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
    init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]

    # Drop all size-1 dimensions before comparing against the PyTorch parameter shapes.
    init_params = [arr.squeeze() for arr in init_params]

    # Array 0 holds the position embeddings and array 1 the token embeddings.
    try:
        assert model.tokens_embed.weight.shape == init_params[1].shape
        assert model.positions_embed.weight.shape == init_params[0].shape
    except AssertionError as e:
        e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
        e.args += (model.positions_embed.weight.shape, init_params[0].shape)
        raise

    model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
    model.positions_embed.weight.data = torch.from_numpy(init_params[0])

    # Discard the entries consumed above. NOTE(review): one name but two arrays are dropped;
    # presumably the checkpoint lists a single name for both embedding arrays — confirm.
    names.pop(0)
    init_params.pop(0)
    init_params.pop(0)

    for name, array in zip(names, init_params):
        name = name[6:]  # skip "model/"
        assert name[-2:] == ":0"
        name = name[:-2]
        name = name.split('/')
        pointer = model
        for m_name in name:
            # Split names such as "h11" into ["h", "11", ""] so the number indexes a list.
            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                scope_names = re.split(r'(\d+)', m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ('g', 'w'):
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'b':
                pointer = getattr(pointer, 'bias')
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                pointer = pointer[int(scope_names[1])]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model

thomwolf's avatar
thomwolf committed
117
118
119
120
121
122
123
124
125

def gelu(x):
    """Gaussian Error Linear Unit, tanh approximation.

    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))
    """
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))


def swish(x):
    """Swish activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


126
127
# Activation functions selectable through ``config.afn``. ``MLP`` stores the entry as
# ``self.act`` and calls ``self.act(tensor)``, so every value must be a callable that is applied
# directly to a tensor. "relu" therefore maps to the functional form: mapping it to the
# ``nn.ReLU`` class would construct a module (``nn.ReLU(inplace=<tensor>)``) instead of
# applying ReLU.
ACT_FNS = {"relu": nn.functional.relu, "swish": swish, "gelu": gelu}

thomwolf's avatar
thomwolf committed
128

129
class OpenAIGPTConfig(PretrainedConfig):
    """
    Configuration class to store the configuration of a `OpenAIGPTModel`.

    Args:
        vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `OpenAIGPTModel`, or the
            path to a JSON configuration file whose key/value pairs are copied onto the config.
        n_positions: Number of positional embeddings.
        n_ctx: Size of the causal mask (usually same as n_positions).
        n_embd: Dimensionality of the embeddings and hidden states.
        n_layer: Number of hidden layers in the Transformer encoder.
        n_head: Number of attention heads for each attention layer in
            the Transformer encoder.
        afn: Name of the non-linear activation function used in the feed-forward layers;
            must be a key of ``ACT_FNS``: "gelu", "relu" or "swish".
        resid_pdrop: The dropout probability applied to the output of the attention
            projection and to the output of the feed-forward layers.
        embd_pdrop: The dropout ratio for the embeddings.
        attn_pdrop: The dropout ratio for the attention probabilities.
        layer_norm_epsilon: epsilon to use in the layer norm layers.
        initializer_range: The standard deviation of the normal distribution used to
            initialize all weight matrices.
        predict_special_tokens: should we predict special tokens (when the model has a LM head).
        num_labels: Number of labels, stored as ``self.num_labels`` for task-specific heads.
        summary_type, summary_use_proj, summary_activation, summary_proj_to_labels,
        summary_first_dropout: Options stored on the config for the sequence-summary head
            (``SequenceSummary``); their semantics are defined there.
        **kwargs: forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: if the first argument is neither an int nor a path string.
    """
    pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size_or_config_json_file=40478,
        n_positions=512,
        n_ctx=512,
        n_embd=768,
        n_layer=12,
        n_head=12,
        afn="gelu",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        predict_special_tokens=True,

        num_labels=1,
        summary_type='token_ids',
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        **kwargs
    ):
        """Constructs OpenAIGPTConfig.
        """
        super(OpenAIGPTConfig, self).__init__(**kwargs)

        # On Python 2 a path may also arrive as a ``unicode`` object; the version check
        # short-circuits so ``unicode`` is never evaluated on Python 3.
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.load(reader)
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.n_ctx = n_ctx
            self.n_positions = n_positions
            self.n_embd = n_embd
            self.n_layer = n_layer
            self.n_head = n_head
            self.afn = afn
            self.resid_pdrop = resid_pdrop
            self.embd_pdrop = embd_pdrop
            self.attn_pdrop = attn_pdrop
            self.layer_norm_epsilon = layer_norm_epsilon
            self.initializer_range = initializer_range
            self.predict_special_tokens = predict_special_tokens

            self.num_labels = num_labels
            self.summary_type = summary_type
            self.summary_use_proj = summary_use_proj
            self.summary_activation = summary_activation
            self.summary_first_dropout = summary_first_dropout
            self.summary_proj_to_labels = summary_proj_to_labels
        else:
            raise ValueError(
                "First argument must be either a vocabulary size (int) "
                "or the path to a pretrained model config file (str)"
            )

    @property
    def max_position_embeddings(self):
        """Alias of ``n_positions``."""
        return self.n_positions

    @property
    def hidden_size(self):
        """Alias of ``n_embd``."""
        return self.n_embd

    @property
    def num_attention_heads(self):
        """Alias of ``n_head``."""
        return self.n_head

    @property
    def num_hidden_layers(self):
        """Alias of ``n_layer``."""
        return self.n_layer

thomwolf's avatar
thomwolf committed
233
234

class Attention(nn.Module):
    """Causal multi-head self-attention layer.

    Args:
        nx: size of the input features (``config.n_embd``); also used as the attention state size.
        n_ctx: side length of the lower-triangular causal mask registered as the ``bias`` buffer.
        config: model configuration; ``n_head``, ``output_attentions``, ``attn_pdrop`` and
            ``resid_pdrop`` are read from it.
        scale: if True, attention scores are divided by sqrt(head dimension).
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        # State size equals input size (kept as a separate name to mirror the TF implementation).
        n_state = nx
        assert n_state % config.n_head == 0
        causal = torch.tril(torch.ones(n_ctx, n_ctx))
        self.register_buffer("bias", causal.view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.output_attentions = config.output_attentions

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def prune_heads(self, heads):
        """Remove the given heads (indices into the current heads) from this layer."""
        if len(heads) == 0:
            return
        head_size = self.split_size // self.n_head
        keep = torch.ones(self.n_head, head_size)
        for head in heads:
            keep[head] = 0
        keep = keep.view(-1).contiguous().eq(1)
        kept_index = torch.arange(len(keep))[keep].long()
        # c_attn emits query, key and value concatenated along the last dimension,
        # so the kept columns are selected inside each of the three slices.
        qkv_index = torch.cat([kept_index,
                               kept_index + self.split_size,
                               kept_index + (2 * self.split_size)])
        self.c_attn = prune_conv1d_layer(self.c_attn, qkv_index, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, kept_index, dim=0)
        # Shrink the hyper-parameters to the remaining heads.
        self.split_size = head_size * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)

    def _attn(self, q, k, v, head_mask=None):
        """Masked scaled dot-product attention; returns [context] or [context, probabilities]."""
        scores = torch.matmul(q, k)
        if self.scale:
            scores = scores / math.sqrt(v.size(-1))
        # The registered mask may be larger than the current sequence: crop it, then push
        # masked positions to -1e9 (TF implementation: mask_attn_weights).
        mask = self.bias[:, :, :scores.size(-2), :scores.size(-1)]
        scores = scores * mask + -1e9 * (1 - mask)

        probs = nn.functional.softmax(scores, dim=-1)
        probs = self.attn_dropout(probs)

        # Optionally nullify whole heads.
        if head_mask is not None:
            probs = probs * head_mask

        context = torch.matmul(probs, v)
        return [context, probs] if self.output_attentions else [context]

    def merge_heads(self, x):
        """(batch, head, seq, head_dim) -> (batch, seq, head * head_dim); TF: merge_states."""
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))

    def split_heads(self, x, k=False):
        """(batch, seq, features) -> (batch, head, seq, head_dim), or (batch, head, head_dim, seq)
        for keys so that ``q @ k`` directly yields the scores; TF: split_states."""
        head_dim = x.size(-1) // self.n_head
        x = x.view(*(x.size()[:-1] + (self.n_head, head_dim)))
        return x.permute(0, 2, 3, 1) if k else x.permute(0, 2, 1, 3)

    def forward(self, x, head_mask=None):
        """Return ``[output]`` or ``[output, attention_probabilities]``."""
        query, key, value = self.c_attn(x).split(self.split_size, dim=2)
        attn_outputs = self._attn(self.split_heads(query),
                                  self.split_heads(key, k=True),
                                  self.split_heads(value),
                                  head_mask)
        merged = self.merge_heads(attn_outputs[0])
        projected = self.resid_dropout(self.c_proj(merged))
        return [projected] + attn_outputs[1:]  # a, (attentions)
thomwolf's avatar
thomwolf committed
318
319
320


class MLP(nn.Module):
    """Position-wise feed-forward network: Conv1D -> activation -> Conv1D -> dropout.

    Args:
        n_state: inner (expanded) size of the hidden layer.
        config: model configuration; ``n_embd``, ``afn`` and ``resid_pdrop`` are read from it.
    """

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        width = config.n_embd
        self.c_fc = Conv1D(n_state, width)
        self.c_proj = Conv1D(width, n_state)
        self.act = ACT_FNS[config.afn]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        projected = self.c_proj(self.act(self.c_fc(x)))
        return self.dropout(projected)


class Block(nn.Module):
    """One transformer layer with layer norm applied after each residual addition:
    ``n = ln_1(x + attn(x))`` followed by ``h = ln_2(n + mlp(n))``.
    """

    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        width = config.n_embd
        self.attn = Attention(width, n_ctx, config, scale)
        self.ln_1 = LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * width, config)
        self.ln_2 = LayerNorm(width, eps=config.layer_norm_epsilon)

    def forward(self, x, head_mask=None):
        """Return ``[hidden_states]`` followed by any extra outputs of the attention layer."""
        attn_outputs = self.attn(x, head_mask=head_mask)
        normed = self.ln_1(x + attn_outputs[0])
        hidden = self.ln_2(normed + self.mlp(normed))
        return [hidden] + attn_outputs[1:]
thomwolf's avatar
thomwolf committed
354
355


356
class OpenAIGPTPreTrainedModel(PreTrainedModel):
    """ Abstract base for the OpenAI GPT models: wires the config class, the pretrained archive
        map and the TF-weights loader into ``PreTrainedModel`` and defines weight initialization.
    """
    config_class = OpenAIGPTConfig
    pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_openai_gpt
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights of ``module`` (used with ``nn.Module.apply``).

        Linear / Embedding / Conv1D weights ~ N(0, config.initializer_range) and their biases
        (if any) are zeroed; LayerNorm gets weight 1 and bias 0.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Plain normal init, whereas TF uses truncated_normal
            # (cf https://github.com/pytorch/pytorch/pull/5617).
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            has_bias = isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None
            if has_bias:
                module.bias.data.zero_()
            return
        if isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
thomwolf's avatar
thomwolf committed
380
381


thomwolf's avatar
thomwolf committed
382
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
    """OpenAI GPT model ("Improving Language Understanding by Generative Pre-Training").

    OpenAI GPT uses a single embedding matrix (``tokens_embed``) to store the word embeddings and
    any special-token embeddings. Special tokens (e.g. [SEP], [CLS]) are not pre-trained and need
    to be trained during fine-tuning if you use them. ``_resize_token_embeddings`` replaces the
    token embedding matrix with one of a new size.

    The embeddings are ordered as follows in the token embeddings matrix:

    ::

        [0,                                                         ----------------------
         ...                                                        -> word embeddings
         config.vocab_size - 1,                                     ______________________
         config.vocab_size,
         ...                                                        -> special embeddings
         total_tokens_embeddings - 1]                               ______________________

    where ``total_tokens_embeddings`` is the number of rows of ``tokens_embed``.

    You should use the associated indices to index the embeddings.

    Args:
        `config`: an OpenAIGPTConfig instance with the configuration to build a new model.
            ``config.output_attentions`` / ``config.output_hidden_states`` control whether the
            per-layer attention probabilities / hidden states are also returned.

    Example::

        config = modeling_openai.OpenAIGPTConfig()
        model = modeling_openai.OpenAIGPTModel(config)
    """

    def __init__(self, config):
        super(OpenAIGPTModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
        self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

        self.apply(self.init_weights)

    def _resize_token_embeddings(self, new_num_tokens):
        """ Replace ``tokens_embed`` by a resized copy with ``new_num_tokens`` rows and return it. """
        self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
        return self.tokens_embed

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
        """
        Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**

        Args:
            `input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
                where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings - 1].
            `position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
                with the position indices (selected in the range [0, config.n_positions - 1]).
                Defaults to 0, 1, ..., sequence_length - 1.
            `token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids.
                You can use it to add a third type of embedding to each input token in the sequence
                (the previous two being the word and position embeddings). Token types are looked up
                in the token embedding matrix (there is no separate token-type table).
                The input, position and token_type embeddings are summed inside the Transformer before the first
                self-attention block.
            `head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with values between 0 and 1.
                It is multiplied into the attention probabilities to nullify some heads of the transformer:
                1.0 => head is kept, 0.0 => head is fully masked.

        Returns:
            A tuple ``(last_hidden_state, (all_hidden_states), (all_attentions))``:
              - ``last_hidden_state``: ``torch.FloatTensor`` of shape [batch_size, sequence_length, hidden_size]
                (or more generally [d_1, ..., d_n, sequence_length, hidden_size]);
              - ``all_hidden_states`` (only if ``config.output_hidden_states``): tuple of n_layer + 1 tensors with
                the same shape, the embedding output followed by the output of each layer;
              - ``all_attentions`` (only if ``config.output_attentions``): tuple of n_layer attention-probability
                tensors of shape [d_1 * ... * d_n, num_heads, sequence_length, sequence_length].

        Example::

            # Already been converted into BPE token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

            last_hidden_state = model(input_ids)[0]
        """
        if position_ids is None:
            position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Prepare head mask if needed: 1.0 in head_mask means we keep the head.
        # attention_probs has shape bsz x n_heads x N x N, so head_mask is reshaped to
        # n_layer x 1 x n_heads x 1 x 1 and broadcast over the batch and sequence dims.
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                # A separate head_mask is given for each layer.
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            # Match the parameter dtype (switch to float if needed + fp16 compatibility).
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
        else:
            head_mask = [None] * self.config.n_layer

        # Flatten any leading dims into one batch dim; the original shape is restored on output.
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.tokens_embed(input_ids)
        position_embeds = self.positions_embed(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.tokens_embed(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        all_attentions = ()
        all_hidden_states = ()
        for i, block in enumerate(self.h):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(hidden_states, head_mask[i])
            hidden_states = outputs[0]
            if self.output_attentions:
                all_attentions = all_attentions + (outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

        outputs = (hidden_states.view(*output_shape),)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (all hidden states), (all attentions)
thomwolf's avatar
thomwolf committed
537

538

thomwolf's avatar
thomwolf committed
539
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
    """OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").

    OpenAI GPT uses a single embedding matrix to store the word and special embeddings.
    Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
    Special tokens need to be trained during the fine-tuning if you use them.

    The embeddings are ordered as follows in the token embeddings matrix:

    ::

        [0,                                                         ----------------------
         ...                                                        -> word embeddings
         config.vocab_size - 1,                                     ______________________
         config.vocab_size,
         ...                                                        -> special embeddings
         total_tokens_embeddings - 1]                               ______________________

    where ``total_tokens_embeddings`` is the number of rows of ``transformer.tokens_embed``.
    The output projection ``lm_head`` is tied to (or cloned from) that matrix by ``tie_weights``.

    You should use the associated indices to index the embeddings.

    Args:
        `config`: an OpenAIGPTConfig instance with the configuration to build a new model.
            ``config.output_attentions`` / ``config.output_hidden_states`` control whether the
            per-layer attention probabilities / hidden states are also returned.

    Example::

        config = modeling_openai.OpenAIGPTConfig()
        model = modeling_openai.OpenAIGPTLMHeadModel(config)
    """

    def __init__(self, config):
        super(OpenAIGPTLMHeadModel, self).__init__(config)
        self.transformer = OpenAIGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.apply(self.init_weights)
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.tokens_embed)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
        """
        Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**

        Args:
            `input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
                where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings - 1].
            `position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
                with the position indices (selected in the range [0, config.n_positions - 1]).
            `token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids.
                You can use it to add a third type of embedding to each input token in the sequence
                (the previous two being the word and position embeddings).
                The input, position and token_type embeddings are summed inside the Transformer before the first
                self-attention block.
            `lm_labels`: optional language modeling labels: ``torch.LongTensor`` with the same shape as input_ids,
                with indices selected in [-1, 0, ..., vocab_size - 1]. Labels set to -1 are ignored (masked); the loss
                is only computed for the other labels. Labels are shifted inside the model (the logits at position i
                are scored against the label at position i + 1), so ``lm_labels`` can simply be ``input_ids``.
            `head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with values between 0 and 1.
                It is multiplied into the attention probabilities to nullify some heads of the transformer:
                1.0 => head is kept, 0.0 => head is fully masked.

        Returns:
            A tuple ``((loss), lm_logits, (all_hidden_states), (all_attentions))``:
              - ``loss`` (only if ``lm_labels`` is given): scalar cross-entropy language modeling loss;
              - ``lm_logits``: ``torch.FloatTensor`` of shape [batch_size, sequence_length, total_tokens_embeddings]
                (or more generally [d_1, ..., d_n, sequence_length, total_tokens_embeddings]);
              - ``all_hidden_states`` / ``all_attentions``: as returned by ``OpenAIGPTModel``, when enabled in the config.

        Example::

            # Already been converted into BPE token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

            lm_logits = model(input_ids)[0]
            loss = model(input_ids, lm_labels=input_ids)[0]
        """
        transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, (all hidden states), (all attentions)
thomwolf's avatar
thomwolf committed
645

646

thomwolf's avatar
thomwolf committed
647
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
    """OpenAI GPT model with a Language Modeling and a Multiple Choice head ("Improving Language Understanding by Generative Pre-Training").

    OpenAI GPT uses a single embedding matrix to store the word and special embeddings.
    Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
    Special tokens need to be trained during the fine-tuning if you use them.
    The number of special embeddings can be controlled using the ``set_num_special_tokens(num_special_tokens)``
    function.
    NOTE(review): this class does not define ``set_num_special_tokens`` itself; confirm the helper
    still exists on the base class / transformer.

    The embeddings are ordered as follow in the token embeddings matrix:

    ::

        [0,                                                         ----------------------
         ...                                                        -> word embeddings
         config.vocab_size - 1,                                     ______________________
         config.vocab_size,
         ...                                                        -> special embeddings
         config.vocab_size + n_special - 1]                         ______________________

    where ``total_tokens_embeddings`` is:

    ::

        total_tokens_embeddings = config.vocab_size + n_special

    You should use the associated indices to index the embeddings.

    Args:
        `config`: an OpenAIGPTConfig class instance with the configuration to build a new model.
            This is the only constructor argument; output options (e.g. whether attentions are
            returned) are presumably read from the configuration object -- TODO confirm against
            OpenAIGPTConfig.

    Example::

        config = modeling_openai.OpenAIGPTConfig()
        model = modeling_openai.OpenAIGPTDoubleHeadsModel(config)
    """

    def __init__(self, config):
        super(OpenAIGPTDoubleHeadsModel, self).__init__(config)

        self.transformer = OpenAIGPTModel(config)
        # Projection from hidden states to token logits; its weight is shared with (or cloned
        # from) the input token embeddings in ``tie_weights``.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Summarizes ``hidden_states`` at the positions given by ``mc_token_ids``; ``forward``
        # squeezes its last dimension to obtain the multiple-choice logits.
        self.multiple_choice_head = SequenceSummary(config)

        # Initialize all sub-modules first, then tie, so the tied weight is not re-initialized.
        self.apply(self.init_weights)
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.tokens_embed)

    def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
                position_ids=None, head_mask=None):
        """
        Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**

        Args:
            `input_ids`: a ``torch.LongTensor`` of shape [batch_size, num_choices, sequence_length] with the BPE token
                indices selected in the range [0, total_tokens_embeddings[
            `mc_token_ids`: a ``torch.LongTensor`` of shape [batch_size, num_choices] with the index of the token from
                which we should take the hidden state to feed the multiple choice classifier (usually last token of the sequence)
            `lm_labels`: optional language modeling labels: ``torch.LongTensor`` of shape [batch_size, num_choices, sequence_length]
                with indices selected in [-1, 0, ..., total_tokens_embeddings - 1]. All labels set to -1 are ignored (masked);
                the loss is only computed for the labels set in [0, ..., total_tokens_embeddings - 1].
                The labels are shifted inside the model (the logits at position ``i`` are scored against the label at
                position ``i + 1``), so ``lm_labels`` can simply be set to ``input_ids``.
            `mc_labels`: optional multiple choice labels: ``torch.LongTensor`` of shape [batch_size]
                with indices selected in [0, ..., num_choices - 1].
            `token_type_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
                You can use it to add a third type of embedding to each input token in the sequence
                (the previous two being the word and position embeddings).
                The input, position and token_type embeddings are summed inside the Transformer before the first
                self-attention block.
            `position_ids`: an optional ``torch.LongTensor`` with the same shape as input_ids
                with the position indices (selected in the range [0, config.n_positions - 1[.
            `head_mask`: an optional ``torch.Tensor`` of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
                It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.

        Returns:
            A tuple whose leading elements depend on which labels were given:

                - ``lm_loss`` (only if ``lm_labels`` is not ``None``): the language modeling loss
                - ``mc_loss`` (only if ``mc_labels`` is not ``None``): the multiple choice loss
                - ``lm_logits``: the language modeling logits as a ``torch.FloatTensor`` of size
                  [batch_size, num_choices, sequence_length, total_tokens_embeddings]
                - ``mc_logits``: the multiple choice logits as a ``torch.FloatTensor`` of size
                  [batch_size, num_choices]
                - any further outputs returned by the underlying transformer after its hidden states
                  (e.g. all hidden states and/or attentions, depending on the configuration)

        Example::

            # Already been converted into BPE token ids
            input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])  # (bsz, number of choices, seq length)
            mc_token_ids = torch.LongTensor([[2, 1]])  # (bsz, number of choices)

            outputs = model(input_ids, mc_token_ids)
            lm_logits, multiple_choice_logits = outputs[:2]
            # or
            outputs = model.forward(input_ids, mc_token_ids)
        """
        transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
                            mc_labels.view(-1))
            outputs = (loss,) + outputs
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            # Flatten the tokens; labels equal to -1 are excluded from the loss.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            # Prepended after the mc loss, so the lm loss ends up first in the tuple.
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)