# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT model."""

18

thomwolf's avatar
thomwolf committed
19
import json
thomwolf's avatar
thomwolf committed
20
import logging
21
22
import math
import os
thomwolf's avatar
thomwolf committed
23
24
25

import torch
import torch.nn as nn
thomwolf's avatar
thomwolf committed
26
from torch.nn import CrossEntropyLoss
thomwolf's avatar
thomwolf committed
27

28
from .configuration_openai import OpenAIGPTConfig
Lysandre's avatar
Lysandre committed
29
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
Aymeric Augustin's avatar
Aymeric Augustin committed
30
31
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer

thomwolf's avatar
thomwolf committed
32

thomwolf's avatar
thomwolf committed
33
34
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    # shortcut name -> URL of the pretrained PyTorch weights file
    "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin",
}
38

39

40
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
    """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here)

    The folder must contain ``parameters_names.json``, ``params_shapes.json`` and the ten
    shards ``params_0.npy`` ... ``params_9.npy``. The shards are concatenated and split back
    into one array per entry of ``params_shapes.json``. The first two arrays are copied into
    ``model.positions_embed`` and ``model.tokens_embed``. The remaining arrays are matched to
    parameters of ``model`` by walking their names: for instance ``model/h0/attn/c_attn/w:0``
    resolves to ``model.h[0].attn.c_attn.weight``.

    Args:
        model: PyTorch model whose parameters are overwritten in place.
        config: not used by this loader.
        openai_checkpoint_folder_path: folder holding the files above. If the path contains
            ``.ckpt``, its parent directory is used instead.

    Returns:
        ``model``, with its weights overwritten.

    Raises:
        AssertionError: if a checkpoint array's shape does not match its target parameter
            (the mismatching shapes are given as the exception args), or if a variable name
            does not end with ``:0``.
    """
    import re
    import numpy as np

    if ".ckpt" in openai_checkpoint_folder_path:
        openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)

    logger.info("Loading weights from {}".format(openai_checkpoint_folder_path))

    names_path = os.path.join(openai_checkpoint_folder_path, "parameters_names.json")
    with open(names_path, "r", encoding="utf-8") as names_handle:
        names = json.load(names_handle)
    shapes_path = os.path.join(openai_checkpoint_folder_path, "params_shapes.json")
    with open(shapes_path, "r", encoding="utf-8") as shapes_handle:
        shapes = json.load(shapes_handle)

    # The parameters are stored as one flat stream across 10 shards: concatenate them, then
    # cut at the cumulative parameter sizes (the last split is the empty tail, hence [:-1]).
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
    init_params = [
        np.load(os.path.join(openai_checkpoint_folder_path, "params_{}.npy".format(n))) for n in range(10)
    ]
    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
    init_params = [arr.squeeze() for arr in init_params]

    # Positions and tokens are kept as two separate embedding matrices (index 0 and 1).
    if (
        model.tokens_embed.weight.shape != init_params[1].shape
        or model.positions_embed.weight.shape != init_params[0].shape
    ):
        raise AssertionError(
            model.tokens_embed.weight.shape,
            init_params[1].shape,
            model.positions_embed.weight.shape,
            init_params[0].shape,
        )

    model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
    model.positions_embed.weight.data = torch.from_numpy(init_params[0])
    # NOTE(review): one name is dropped for two arrays. Presumably the checkpoint lists a single
    # (formerly concatenated) embedding name; confirm against parameters_names.json.
    names.pop(0)
    # Pop position and token embedding arrays
    init_params.pop(0)
    init_params.pop(0)

    for name, array in zip(names, init_params):
        name = name[6:]  # skip "model/"
        if name[-2:] != ":0":
            raise AssertionError("Unexpected variable name {!r}: expected a ':0' suffix".format(name))
        name = name[:-2].split("/")

        pointer = model
        for m_name in name:
            # "h0" -> ["h", "0", ""]: attribute name followed by an index into it.
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ("g", "w"):
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model

thomwolf's avatar
thomwolf committed
116
117
118
119
120
121
122
123
124

def gelu(x):
    """Tanh approximation of the Gaussian Error Linear Unit, applied elementwise to ``x``."""
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    """Swish activation ``x * sigmoid(x)``, applied elementwise to ``x``."""
    return x * torch.sigmoid(x)


# Activation lookup used by MLP via ``config.afn``. Every value must be a callable mapping a
# tensor to a tensor. ``nn.ReLU`` itself is a class, and calling it on a tensor would construct
# a module instead of applying the activation, so an instance is stored here.
ACT_FNS = {"relu": nn.ReLU(), "swish": swish, "gelu": gelu}

thomwolf's avatar
thomwolf committed
127
128

class Attention(nn.Module):
    """Causal multi-head self-attention sub-layer of a transformer ``Block``.

    ``c_attn`` projects the input to concatenated query/key/value states. These are split
    into ``n_head`` heads, attended over with a causal (lower-triangular) mask, merged back,
    and projected by ``c_proj``.
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Lower-triangular ones: position i may only attend to positions <= i.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.output_attentions = config.output_attentions

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove attention heads from this layer.

        Args:
            heads: iterable of head indices, numbered as in the original (unpruned) layer.
                Heads that were already pruned are ignored.

        If no head is left to prune after discarding already-pruned ones, the method returns
        without calling ``prune_conv1d_layer`` on ``c_attn``/``c_proj``.
        """
        heads = set(heads) - self.pruned_heads
        # Check emptiness *after* filtering, so re-pruning is a true no-op.
        if len(heads) == 0:
            return
        head_size = self.split_size // self.n_head
        mask = torch.ones(self.n_head, head_size)
        for head in heads:
            # Shift to the head's current index, accounting for heads removed earlier.
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # c_attn outputs [query | key | value] (see forward); keep the same columns in each third.
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
        # Update hyper params
        self.split_size = head_size * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        """Scaled, causally-masked dot-product attention.

        Args:
            q, v: per-head states as produced by ``split_heads``, ``(batch, n_head, seq, head_dim)``.
            k: keys as produced by ``split_heads(k=True)``, ``(batch, n_head, head_dim, seq)``.
            attention_mask: optional additive mask, broadcast against the raw scores.
            head_mask: optional multiplicative mask applied to the attention probabilities.

        Returns:
            ``[context]``, or ``[context, probabilities]`` when ``output_attentions`` is set.
        """
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # Causal mask (TF implem: mask_attn_weights). self.bias is built for n_ctx positions
        # and may be larger than w, so crop it; masked scores are pushed to -1e4.
        b = self.bias[:, :, : w.size(-2), : w.size(-1)]
        w = w * b + -1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = torch.softmax(w, dim=-1)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        """Inverse of ``split_heads``: ``(batch, n_head, seq, head_dim)`` -> ``(batch, seq, n_head * head_dim)``."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        """Split the last dim into heads: ``(batch, seq, dim)`` -> ``(batch, n_head, seq, head_dim)``.

        With ``k=True`` the result is ``(batch, n_head, head_dim, seq)``, ready for ``q @ k``.
        """
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, attention_mask=None, head_mask=None):
        """Return ``[output]`` or ``[output, attention_probabilities]`` (when ``output_attentions``)."""
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a] + attn_outputs[1:]
        return outputs  # a, (attentions)
thomwolf's avatar
thomwolf committed
220
221
222


class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: expand with ``c_fc``, activate, project back with ``c_proj``, dropout."""

    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        embed_dim = config.n_embd
        self.c_fc = Conv1D(n_state, embed_dim)
        self.c_proj = Conv1D(embed_dim, n_state)
        # Activation chosen by name from ACT_FNS (config.afn).
        self.act = ACT_FNS[config.afn]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        hidden = self.c_fc(x)
        hidden = self.act(hidden)
        projected = self.c_proj(hidden)
        return self.dropout(projected)


class Block(nn.Module):
    """One transformer layer: self-attention then MLP, each added residually and followed by a LayerNorm."""

    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        hidden_size = config.n_embd
        self.attn = Attention(hidden_size, n_ctx, config, scale)
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * hidden_size, config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

    def forward(self, x, attention_mask=None, head_mask=None):
        """Return ``[hidden_states]`` plus any extra outputs of the attention sub-layer."""
        attn_outputs = self.attn(x, attention_mask=attention_mask, head_mask=head_mask)
        attended, extras = attn_outputs[0], attn_outputs[1:]
        # Post-LayerNorm residual connections.
        normed = self.ln_1(x + attended)
        hidden = self.ln_2(normed + self.mlp(normed))
        return [hidden] + extras
thomwolf's avatar
thomwolf committed
256
257


258
class OpenAIGPTPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = OpenAIGPTConfig
    pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_openai_gpt
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """ Initialize the weights of a single sub-module.

        Linear, Conv1D and Embedding weights are drawn from N(0, config.initializer_range).
        Linear/Conv1D biases, when present, are zeroed. LayerNorm is reset to weight 1 and
        bias 0. Any other module type is left untouched.
        """
        projection_types = (nn.Linear, Conv1D)
        if isinstance(module, projection_types + (nn.Embedding,)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            has_bias = isinstance(module, projection_types) and module.bias is not None
            if has_bias:
                module.bias.data.zero_()
            return
        if isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
thomwolf's avatar
thomwolf committed
280
281


Lysandre's avatar
Lysandre committed
282
OPENAI_GPT_START_DOCSTRING = r"""

    The OpenAI GPT model was introduced in `Improving Language Understanding by Generative Pre-Training`_.
    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`Improving Language Understanding by Generative Pre-Training`:
        https://openai.com/blog/language-unsupervised/

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.OpenAIGPTConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

Lysandre's avatar
Lysandre committed
298
299
300
301
302
303
OPENAI_GPT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.OpenAIGPTTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`__
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""

335
336
337
338
339

@add_start_docstrings(
    "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
    OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
        self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

        self.init_weights()

    def get_input_embeddings(self):
        return self.tokens_embed

    def set_input_embeddings(self, new_embeddings):
        self.tokens_embed = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (config) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        model = OpenAIGPTModel.from_pretrained('openai-gpt')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            # Flatten any leading dims, e.g. (batch, num_choices, seq) -> (batch * num_choices, seq).
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if position_ids is None:
            # Code is different from when we had a single embedding matrice from position and token embeddings
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(input_shape[-1], dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
        else:
            # Flatten like input_ids so the embeddings broadcast against the flattened batch.
            position_ids = position_ids.reshape(-1, position_ids.size(-1))

        # Attention mask.
        if attention_mask is not None:
            # Flatten extra leading dims like input_ids, e.g. (batch, num_choices, seq) ->
            # (batch * num_choices, seq); otherwise the mask below would become 5-D.
            if attention_mask.dim() > 2:
                attention_mask = attention_mask.reshape(-1, attention_mask.size(-1))
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                # We can specify head_mask for each layer
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            # switch to float if need + fp16 compatibility
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
        else:
            head_mask = [None] * self.config.n_layer

        if inputs_embeds is None:
            inputs_embeds = self.tokens_embed(input_ids)
        position_embeds = self.positions_embed(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            # Token types share the token embedding matrix.
            token_type_embeds = self.tokens_embed(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        all_attentions = ()
        all_hidden_states = ()
        for i, block in enumerate(self.h):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(hidden_states, attention_mask, head_mask[i])
            hidden_states = outputs[0]
            if self.output_attentions:
                all_attentions = all_attentions + (outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

        outputs = (hidden_states.view(*output_shape),)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, (all hidden states), (all attentions)
thomwolf's avatar
thomwolf committed
488

489

490
491
@add_start_docstrings(
    """OpenAI GPT Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = OpenAIGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:obj:`~transformers.OpenAIGPTConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
            Language modeling loss.
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]

    """
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens; CrossEntropyLoss ignores label -100 by default.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, (all hidden states), (all attentions)
thomwolf's avatar
thomwolf committed
579

580

581
582
@add_start_docstrings(
    """OpenAI GPT Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings,
    the classification head takes as input the hidden state at a specified classification token index in the input sequence.
""",
    OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # One score per choice: forward() squeezes the last dim of the multiple-choice
        # head's output. Presumably SequenceSummary sizes its projection from
        # config.num_labels (confirm in modeling_utils). Note: this mutates the config
        # object passed in by the caller.
        config.num_labels = 1
        self.transformer = OpenAIGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the language-modeling head (the output projection to the vocabulary)."""
        return self.lm_head

    @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
    ):
        r"""
        mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input)
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1]`` (inclusive).
        lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices, sequence_length)`, `optional`, defaults to :obj:`None`)
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size - 1]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size - 1]``
        mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`)
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:obj:`~transformers.OpenAIGPTConfig`) and inputs:
        lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``lm_labels`` is provided):
            Language modeling loss.
        mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``mc_labels`` is provided):
            Multiple choice classification loss.
        lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
        model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt')
        tokenizer.add_special_tokens({'cls_token': '[CLS]'})  # Add a [CLS] to the vocabulary (we should train it also!)
        model.resize_token_embeddings(len(tokenizer))

        choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
        mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0)  # Batch size 1

        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_prediction_scores, mc_prediction_scores = outputs[:2]

    """
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        # Score each choice from its hidden states and classification-token index; the
        # trailing size-1 dimension is dropped so there is one logit per choice.
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        # transformer_outputs[1:] carries the optional hidden states / attentions.
        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]

        # Losses are prepended: mc loss first, then lm loss, so the final order is
        # (lm_loss, mc_loss, lm_logits, mc_logits, ...) as documented above.
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
            outputs = (loss,) + outputs
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            # Flatten the tokens; labels equal to -100 (CrossEntropyLoss's default
            # ignore_index) do not contribute to the loss.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)