"vscode:/vscode.git/clone" did not exist on "832ed6f8e6cf958887f31e14908b9c0ab59d918c"
modeling_openai.py 30.2 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
thomwolf's avatar
thomwolf committed
2
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
thomwolf's avatar
thomwolf committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT model."""

18
19
from __future__ import absolute_import, division, print_function, unicode_literals

20
import collections
thomwolf's avatar
thomwolf committed
21
import json
thomwolf's avatar
thomwolf committed
22
import logging
23
24
import math
import os
thomwolf's avatar
thomwolf committed
25
26
import sys
from io import open
thomwolf's avatar
thomwolf committed
27
28
29

import torch
import torch.nn as nn
thomwolf's avatar
thomwolf committed
30
from torch.nn import CrossEntropyLoss
thomwolf's avatar
thomwolf committed
31
32
from torch.nn.parameter import Parameter

33
34
35
from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
from .configuration_openai import OpenAIGPTConfig
from .file_utils import add_start_docstrings
thomwolf's avatar
thomwolf committed
36

thomwolf's avatar
thomwolf committed
37
38
logger = logging.getLogger(__name__)

# Maps the pretrained shortcut name to the URL of its PyTorch weights file.
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin",
}
40

41

42
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
    """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here)

    The checkpoint folder must contain ``parameters_names.json``,
    ``params_shapes.json`` and the shards ``params_0.npy`` ... ``params_9.npy``.
    The shards are concatenated and cut into one array per entry of
    ``params_shapes.json``.

    Args:
        model: the PyTorch model whose weights are overwritten in place. It must
            expose ``tokens_embed``, ``positions_embed`` and the sub-modules named
            in ``parameters_names.json``.
        config: not used by this function; kept for the common loader signature.
        openai_checkpoint_folder_path: the checkpoint folder, or a path containing
            ``.ckpt`` whose parent directory is used instead.

    Returns:
        The same ``model`` object.

    Raises:
        AssertionError: if an array's shape does not match the PyTorch parameter
            it is loaded into, or a variable name lacks the ``:0`` suffix.
    """
    import re
    import numpy as np

    if '.ckpt' in openai_checkpoint_folder_path:
        openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)

    logger.info("Loading weights from {}".format(openai_checkpoint_folder_path))

    # Context managers so the JSON file handles are closed deterministically.
    with open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8') as names_file:
        names = json.load(names_file)
    with open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8') as shapes_file:
        shapes = json.load(shapes_file)

    # The parameters are stored flattened and concatenated across ten .npy shards:
    # split the concatenation at the cumulative sizes to get one array per shape.
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
    init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]

    # Drop size-1 dimensions so the arrays line up with the PyTorch parameter shapes.
    init_params = [arr.squeeze() for arr in init_params]

    # init_params[0] -> position embeddings, init_params[1] -> token embeddings.
    # Explicit checks instead of `assert` so validation still runs under `python -O`;
    # AssertionError is kept as the exception type for existing callers.
    if (model.tokens_embed.weight.shape != init_params[1].shape
            or model.positions_embed.weight.shape != init_params[0].shape):
        raise AssertionError(model.tokens_embed.weight.shape, init_params[1].shape,
                             model.positions_embed.weight.shape, init_params[0].shape)

    model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
    model.positions_embed.weight.data = torch.from_numpy(init_params[0])
    # One name but two arrays are consumed for the embeddings.
    # NOTE(review): presumably the names file still lists a single combined
    # position+token embedding entry from an older format -- confirm.
    names.pop(0)
    init_params.pop(0)
    init_params.pop(0)

    for name, array in zip(names, init_params):
        name = name[6:]  # skip "model/"
        if name[-2:] != ":0":
            raise AssertionError("Unexpected TF variable name (missing ':0' suffix): {}".format(name))
        name = name[:-2]
        name = name.split('/')
        pointer = model
        for m_name in name:
            # A scope such as "h3" addresses element 3 of the attribute "h".
            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                scope_names = re.split(r'(\d+)', m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ('g', 'w'):
                # Both "g" and "w" variables are loaded into `.weight`.
                pointer = pointer.weight
            elif scope_names[0] == 'b':
                pointer = pointer.bias
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                pointer = pointer[int(scope_names[1])]
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model

thomwolf's avatar
thomwolf committed
116
117
118
119
120
121
122
123
124

def gelu(x):
    """Gaussian Error Linear Unit, tanh approximation.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))) element-wise.
    """
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))


def swish(x):
    """Swish activation: the input gated by its own sigmoid, ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


125
126
# Activation functions selectable through ``config.afn``. Each entry must be a callable
# applied directly to a tensor (the MLP calls ``self.act(tensor)``). "relu" therefore maps
# to the functional form: ``nn.ReLU`` is a module *class*, and ``nn.ReLU(tensor)`` would
# construct a module (with ``inplace=tensor``) instead of applying the activation.
ACT_FNS = {"relu": nn.functional.relu, "swish": swish, "gelu": gelu}

thomwolf's avatar
thomwolf committed
127
128

class Attention(nn.Module):
    """Causal multi-head self-attention.

    Args:
        nx: input hidden size (``config.n_embd`` when built by ``Block``).
        n_ctx: maximum sequence length; sizes the lower-triangular causal mask buffer.
        config: model configuration; ``n_head``, ``output_attentions``, ``attn_pdrop``
            and ``resid_pdrop`` are read from it.
        scale: if True, attention scores are divided by sqrt(head_dim).
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Lower-triangular ones: query position i may attend to key positions <= i.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.output_attentions = config.output_attentions

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads, indexed in the original (unpruned) head numbering.

        Heads that were already pruned are ignored. ``c_attn`` and ``c_proj`` are shrunk
        with ``prune_conv1d_layer`` and ``n_head`` / ``split_size`` are updated to match.
        """
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        heads = set(heads) - self.pruned_heads  # drop heads that are already gone
        for head in heads:
            # Shift the index left by the number of already-pruned heads before it.
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # c_attn's output is split into [query | key | value] in forward, so keep the
        # same columns inside each of the three thirds.
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        """Scaled, causally-masked attention.

        ``k`` must already be transposed (as produced by ``split_heads(..., k=True)``).
        Returns ``[context]`` plus the attention probabilities when ``output_attentions``.
        """
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # self.bias is sized for n_ctx and may be larger than w, so crop it.
        b = self.bias[:, :, : w.size(-2), : w.size(-1)]
        # TF implem uses -1e9 for masked (future) positions. In float16 that constant
        # overflows to -inf, and -inf * 0 on the kept positions yields NaN everywhere.
        # Clamp to the dtype's most negative finite value (still -1e9 for fp32/bf16/fp64).
        mask_value = max(-1e9, torch.finfo(w.dtype).min)
        w = w * b + mask_value * (1 - b)

        if attention_mask is not None:
            # Apply the (additive) attention mask
            w = w + attention_mask

        w = nn.functional.softmax(w, dim=-1)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        """(batch, head, seq, head_dim) -> (batch, seq, head * head_dim)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        """(batch, seq, features) -> (batch, head, seq, head_dim), or (batch, head, head_dim, seq) if ``k``."""
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, attention_mask=None, head_mask=None):
        """Self-attention over ``x``; returns ``[output]`` plus attention probs if enabled."""
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a] + attn_outputs[1:]
        return outputs  # a, (attentions)
thomwolf's avatar
thomwolf committed
220
221
222


class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: project up to ``n_state``, activate,
    project back to ``config.n_embd``, then apply residual dropout.
    """

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        hidden = config.n_embd
        # Creation order kept (c_fc before c_proj) so parameter registration is unchanged.
        self.c_fc = Conv1D(n_state, hidden)
        self.c_proj = Conv1D(hidden, n_state)
        self.act = ACT_FNS[config.afn]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))


class Block(nn.Module):
    """One transformer layer: self-attention and an MLP, each followed by a
    residual connection and LayerNorm (normalization applied after the residual add).
    """

    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        width = config.n_embd
        # Sub-module creation order is preserved (affects parameter order / init RNG).
        self.attn = Attention(width, n_ctx, config, scale)
        self.ln_1 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * width, config)
        self.ln_2 = nn.LayerNorm(width, eps=config.layer_norm_epsilon)

    def forward(self, x, attention_mask=None, head_mask=None):
        attn_outputs = self.attn(x, attention_mask=attention_mask, head_mask=head_mask)
        normed = self.ln_1(x + attn_outputs[0])
        block_out = self.ln_2(normed + self.mlp(normed))
        # Forward any extra attention outputs (e.g. attention probabilities) unchanged.
        extras = attn_outputs[1:]
        return [block_out] + extras
thomwolf's avatar
thomwolf committed
256
257


258
class OpenAIGPTPreTrainedModel(PreTrainedModel):
    """ Abstract base class for the OpenAI GPT models.

        Handles weights initialization and a simple interface for downloading
        and loading pretrained models.
    """
    # Class-level settings read by the PreTrainedModel machinery (defined outside this file).
    config_class = OpenAIGPTConfig
    pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_openai_gpt
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """ Initialize the weights of one sub-module in place.

            Linear / Embedding / Conv1D weights are drawn from
            N(0, config.initializer_range) and Linear / Conv1D biases are zeroed;
            LayerNorm gets weight 1 and bias 0. Other modules are left as-is.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # nn.Embedding has no bias, so only Linear / Conv1D are checked.
            carries_bias = isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None
            if carries_bias:
                module.bias.data.zero_()
            return
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
thomwolf's avatar
thomwolf committed
279
280


thomwolf's avatar
thomwolf committed
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
# Shared introduction prepended to the docstrings of the OpenAI GPT model classes.
OPENAI_GPT_START_DOCSTRING = r"""    OpenAI GPT model was proposed in
    `Improving Language Understanding by Generative Pre-Training`_
    by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
    It's a causal (unidirectional) transformer pre-trained using language modeling on a large
    corpus with long range dependencies, the Toronto Book Corpus.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`Improving Language Understanding by Generative Pre-Training`:
        https://openai.com/blog/language-unsupervised/

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~pytorch_transformers.OpenAIGPTConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

thomwolf's avatar
thomwolf committed
302
# Shared description of forward() inputs appended to the OpenAI GPT model docstrings.
OPENAI_GPT_INPUTS_DOCSTRING = r"""    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            GPT is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.
            Indices can be obtained using :class:`pytorch_transformers.OpenAIGPTTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices)
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""

Julien Chaumond's avatar
Julien Chaumond committed
327
@add_start_docstrings("The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
                      OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING)
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            tuple of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            tuple of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        model = OpenAIGPTModel.from_pretrained('openai-gpt')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        super(OpenAIGPTModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
        self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
        return self.tokens_embed

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    def _prepare_attention_mask(self, attention_mask):
        """ Turn a {0, 1} padding mask into an additive mask broadcastable over attention scores.

            Returns None when no mask is given; otherwise a tensor of shape
            (flattened_batch, 1, 1, seq_len) holding 0.0 for kept and -10000.0 for masked positions.
        """
        if attention_mask is None:
            return None
        # Flatten leading dims exactly like input_ids is flattened in forward(), so that a
        # (batch, num_choices, seq) mask matches the (batch * num_choices, seq) inputs.
        attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        # [batch, seq] -> [batch, 1, 1, seq], broadcast to [batch, num_heads, from_seq, to_seq].
        # This is simpler than the triangular causal masking done inside Attention.
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        # Adding a large negative number to the raw scores before the softmax is
        # effectively the same as removing those positions entirely.
        attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        return (1.0 - attention_mask) * -10000.0

    def _prepare_head_mask(self, head_mask):
        """ Broadcast a head mask to one entry per layer (1.0 keeps a head, 0.0 masks it).

            Returns ``[None] * n_layer`` when no mask is given; otherwise a tensor of shape
            (n_layer, 1, n_heads, 1, 1) built from a ``(num_heads,)`` or ``(num_layers, num_heads)`` input.
        """
        if head_mask is None:
            return [None] * self.config.n_layer
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
        return head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        if position_ids is None:
            position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        attention_mask = self._prepare_attention_mask(attention_mask)
        head_mask = self._prepare_head_mask(head_mask)

        # Fold any extra leading dimensions (e.g. multiple choices) into the batch dimension.
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.tokens_embed(input_ids)
        position_embeds = self.positions_embed(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            # Token types are embedded with the token embedding matrix (no separate segment vocabulary).
            token_type_embeds = self.tokens_embed(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        # Restores the caller's leading dimensions on every returned hidden state.
        output_shape = input_shape + (hidden_states.size(-1),)

        all_attentions = ()
        all_hidden_states = ()
        for i, block in enumerate(self.h):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(hidden_states, attention_mask, head_mask[i])
            hidden_states = outputs[0]
            if self.output_attentions:
                all_attentions = all_attentions + (outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

        outputs = (hidden_states.view(*output_shape),)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (all hidden states), (all attentions)
thomwolf's avatar
thomwolf committed
451

452

thomwolf's avatar
thomwolf committed
453
@add_start_docstrings("""OpenAI GPT Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING)
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-1, 0, ..., config.vocab_size - 1]``
            All labels set to ``-1`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size - 1]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            tuple of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            tuple of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]

    """
    def __init__(self, config):
        super(OpenAIGPTLMHeadModel, self).__init__(config)
        self.transformer = OpenAIGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize first, then tie so lm_head ends up sharing the token embedding weights.
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Share the output projection with the input token embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head, self.transformer.tokens_embed)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None):
        base_outputs = self.transformer(input_ids,
                                        attention_mask=attention_mask,
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)
        logits = self.lm_head(base_outputs[0])
        result = (logits,) + base_outputs[1:]
        if labels is None:
            return result  # lm_logits, (all hidden states), (all attentions)

        # Shift so that tokens < n predict n
        logits_for_loss = logits[..., :-1, :].contiguous()
        targets = labels[..., 1:].contiguous()
        # Flatten the tokens; label -1 is ignored.
        vocab_size = logits_for_loss.size(-1)
        criterion = CrossEntropyLoss(ignore_index=-1)
        loss = criterion(logits_for_loss.view(-1, vocab_size), targets.view(-1))
        return (loss,) + result  # loss, lm_logits, (all hidden states), (all attentions)
thomwolf's avatar
thomwolf committed
523

524

thomwolf's avatar
thomwolf committed
525
526
527
@add_start_docstrings("""OpenAI GPT Model transformer with a language modeling and a multiple-choice classification
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
The language modeling head has its weights tied to the input embeddings,
the classification head takes as input the hidden state of a specified classification token index in the input sequence.
""", OPENAI_GPT_START_DOCSTRING, OPENAI_GPT_INPUTS_DOCSTRING)
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
    r"""
        **mc_token_ids**: (`optional`, default to index of the last token of the input) ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1]``.
        **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
            Indices are selected in ``[-1, 0, ..., config.vocab_size - 1]``
            All labels set to ``-1`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size - 1]``
        **mc_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Language modeling loss.
        **mc_loss**: (`optional`, returned when ``mc_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Multiple choice classification loss.
        **lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt')
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})  # Add a [CLS] to the vocabulary (we should train it also!)
        model.resize_token_embeddings(len(tokenizer))  # Grow the (tied) embeddings to cover the new [CLS] token
        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
        mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0)  # Index of the last token ([CLS]); batch size 1
        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

    """
    def __init__(self, config):
        super(OpenAIGPTDoubleHeadsModel, self).__init__(config)

        self.transformer = OpenAIGPTModel(config)
        # Output projection to vocabulary logits; no bias because its weight is
        # shared with (or cloned from) the input token embeddings in tie_weights().
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Produces one score per choice from the hidden states (see forward()).
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.tokens_embed)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                mc_token_ids=None, lm_labels=None, mc_labels=None):
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               head_mask=head_mask)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        # Drop the trailing singleton score dimension so mc_logits has one entry per choice.
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
        # Losses are prepended: mc loss first, then lm loss, so the final order
        # is (lm_loss, mc_loss, lm_logits, mc_logits, ...) as documented above.
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
                            mc_labels.view(-1))
            outputs = (loss,) + outputs
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            # Labels equal to -1 are masked out of the loss.
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)