modeling_ctrl.py 24 KB
Newer Older
keskarnitish's avatar
keskarnitish committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
thomwolf's avatar
thomwolf committed
16
""" PyTorch CTRL model."""
keskarnitish's avatar
keskarnitish committed
17
18
19


import logging
Aymeric Augustin's avatar
Aymeric Augustin committed
20

keskarnitish's avatar
keskarnitish committed
21
22
23
24
25
26
27
import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .configuration_ctrl import CTRLConfig
from .file_utils import add_start_docstrings
28
from .modeling_utils import Conv1D, PreTrainedModel
Aymeric Augustin's avatar
Aymeric Augustin committed
29

keskarnitish's avatar
keskarnitish committed
30
31
32
33
34
35
36

logger = logging.getLogger(name=__name__)

# Pretrained shortcut name (as used in ``from_pretrained('ctrl')``) -> URL of its PyTorch weights.
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(ctrl="https://storage.googleapis.com/sf-ctrl/pytorch/seqlen256_v1.bin")


def angle_defn(pos, i, d_model_size):
    """Return the sinusoid angles ``pos / 10000 ** (2 * (i // 2) / d_model_size)``.

    Args:
        pos: tensor of positions (broadcast against ``i``).
        i: tensor of embedding-dimension indices.
        d_model_size: embedding width used to normalise the exponent.

    Returns:
        ``pos`` multiplied elementwise (with broadcasting) by the per-dimension rates.
    """
    # Dimensions 2j and 2j + 1 share the same frequency.
    exponent = (2 * (i // 2)) / d_model_size
    rates = 1 / torch.pow(10000, exponent)
    return pos * rates


def positional_encoding(position, d_model_size, dtype):
    """Build the fixed sinusoidal position table.

    Args:
        position: number of positions (rows) to encode.
        d_model_size: embedding width.
        dtype: torch dtype of the position/dimension index ranges, and therefore of the table.

    Returns:
        Tensor of shape ``(position, d_model_size)``. It holds the sines of the even-indexed
        angle columns followed by the cosines of the odd-indexed columns. The two halves are
        concatenated, not interleaved.
    """
    positions = torch.arange(position, dtype=dtype)[:, None]
    dims = torch.arange(d_model_size, dtype=dtype)[None, :]
    angle_rads = angle_defn(positions, dims, d_model_size)

    even_sines = torch.sin(angle_rads[:, 0::2])
    odd_cosines = torch.cos(angle_rads[:, 1::2])
    return torch.cat([even_sines, odd_cosines], dim=-1)


def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
    """Softmax attention of ``q`` over ``k``/``v`` with optional causal, padding and head masks.

    Args:
        q, k, v: 4-D tensors. ``k`` is transposed over its last two dims, so the layout is
            ``(batch, heads, length, depth)`` as produced by ``MultiHeadAttention``.
        mask: optional 2-D tensor where 1.0 marks a disallowed position. The bottom-right
            ``query_len x key_len`` window is scaled by ``-1e4`` and added to the scores.
        attention_mask: optional additive mask, added to the scores before the softmax.
        head_mask: optional multiplicative mask, applied to the weights after the softmax.

    Returns:
        ``(output, attention_weights)``, where ``output = attention_weights @ v``.
    """
    logits = torch.matmul(q, k.permute(0, 1, 3, 2)) / np.sqrt(k.shape[-1])

    if mask is not None:
        query_len, key_len = logits.size(-2), logits.size(-1)
        # With a cache, queries are the last `query_len` of the `key_len` positions.
        logits += mask[key_len - query_len : key_len, :key_len] * -1e4

    if attention_mask is not None:
        logits = logits + attention_mask

    weights = torch.softmax(logits, dim=-1)

    if head_mask is not None:
        weights = weights * head_mask

    return torch.matmul(weights, v), weights


class MultiHeadAttention(torch.nn.Module):
    """Multi-head attention with separate query/key/value projections and an output projection.

    Args:
        d_model_size: width of the input and output features. Must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        output_attentions: when True, ``forward`` also returns the attention weights.

    Raises:
        ValueError: if ``d_model_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model_size, num_heads, output_attentions=False):
        super(MultiHeadAttention, self).__init__()
        if d_model_size % num_heads != 0:
            # split_into_heads reshapes to (..., num_heads, depth). With a non-divisible width,
            # that reshape either fails or, for some lengths, silently regroups features.
            raise ValueError(
                "d_model_size ({}) must be divisible by num_heads ({})".format(d_model_size, num_heads)
            )
        self.output_attentions = output_attentions
        self.num_heads = num_heads
        self.d_model_size = d_model_size

        # Per-head feature width; exact because of the divisibility check above.
        self.depth = d_model_size // num_heads

        self.Wq = torch.nn.Linear(d_model_size, d_model_size)
        self.Wk = torch.nn.Linear(d_model_size, d_model_size)
        self.Wv = torch.nn.Linear(d_model_size, d_model_size)

        self.dense = torch.nn.Linear(d_model_size, d_model_size)

    def split_into_heads(self, x, batch_size):
        """Reshape ``x`` to ``(batch_size, num_heads, -1, depth)``."""
        x = x.reshape(batch_size, -1, self.num_heads, self.depth)
        return x.permute([0, 2, 1, 3])

    def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None):
        """Attend from ``q`` over ``k``/``v``. Note the value, key, query argument order.

        Args:
            v, k, q: inputs to the value/key/query projections. The first dim is the batch
                and the last is ``d_model_size``.
            mask: causal mask forwarded to ``scaled_dot_product_attention``.
            layer_past: optional cached ``torch.stack((key, value))`` whose keys/values are
                already split into heads. The new keys/values are appended after it along the
                length dim.
            attention_mask: optional additive mask forwarded to ``scaled_dot_product_attention``.
            head_mask: optional multiplicative per-head mask forwarded to ``scaled_dot_product_attention``.

        Returns:
            ``(output, present)``, plus the attention weights when ``output_attentions`` is set.
            ``present`` is ``torch.stack((k, v))`` and includes any cached positions.
        """
        batch_size = q.shape[0]

        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        q = self.split_into_heads(q, batch_size)
        k = self.split_into_heads(k, batch_size)
        v = self.split_into_heads(v, batch_size)
        if layer_past is not None:
            past_key, past_value = layer_past[0], layer_past[1]
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)
        present = torch.stack((k, v))

        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
        scaled_attention = output[0].permute([0, 2, 1, 3])
        attn = output[1]
        # Merge the heads back into a single d_model_size feature dimension.
        original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
        output = self.dense(original_size_attention)

        outputs = (output, present)
        if self.output_attentions:
            outputs = outputs + (attn,)
        return outputs


def point_wise_feed_forward_network(d_model_size, dff):
    """Return the position-wise feed-forward sub-layer: Linear -> ReLU -> Linear.

    Args:
        d_model_size: input and output width.
        dff: hidden (inner) width.
    """
    layers = [
        torch.nn.Linear(d_model_size, dff),
        torch.nn.ReLU(),
        torch.nn.Linear(dff, d_model_size),
    ]
    return torch.nn.Sequential(*layers)


class EncoderLayer(torch.nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then a feed-forward network.

    Each sub-layer normalises its input, applies dropout to its output, and adds the
    un-normalised input back as a residual.

    Args:
        d_model_size: hidden width.
        num_heads: number of attention heads.
        dff: inner width of the feed-forward network.
        rate: dropout probability for both sub-layer outputs.
        output_attentions: forwarded to ``MultiHeadAttention``.
    """

    def __init__(self, d_model_size, num_heads, dff, rate=0.1, output_attentions=False):
        super(EncoderLayer, self).__init__()

        # Registration order is kept as-is: it fixes parameter order and init RNG order.
        self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, output_attentions)
        self.ffn = point_wise_feed_forward_network(d_model_size, dff)

        self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
        self.layernorm2 = torch.nn.LayerNorm(d_model_size, eps=1e-6)

        self.dropout1 = torch.nn.Dropout(rate)
        self.dropout2 = torch.nn.Dropout(rate)

    def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
        """Run the block on ``x``.

        Returns:
            ``(hidden_states, present)``, plus the attention weights when the attention
            module was built with ``output_attentions=True``.
        """
        # Sub-layer 1: self-attention on the normalised input, with a residual connection.
        h = self.layernorm1(x)
        attn_outputs = self.multi_head_attention(
            h, h, h, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
        )
        out1 = x + self.dropout1(attn_outputs[0])

        # Sub-layer 2: feed-forward on the normalised result, with a residual connection.
        ffn_output = self.dropout2(self.ffn(self.layernorm2(out1)))
        out2 = out1 + ffn_output

        return (out2,) + attn_outputs[1:]


class CTRLPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = CTRLConfig
    pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """ Initialize the weights of a single submodule in place.

        - ``nn.Linear`` / ``nn.Embedding`` / ``Conv1D``: weights are drawn from
          N(0, ``config.initializer_range``). Biases are zeroed (Linear/Conv1D only, when present).
        - ``nn.LayerNorm``: weight is set to 1 and bias to 0.
        - Any other module is left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            std = self.config.initializer_range
            module.weight.data.normal_(mean=0.0, std=std)
            has_bias = isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None
            if has_bias:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()


CTRL_START_DOCSTRING = r"""    CTRL model was proposed in
    `CTRL: A Conditional Transformer Language Model for Controllable Generation`_
    by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
    It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
    corpus of ~140 GB of text data with the first token reserved as a control code (such as Links, Books, Wikipedia etc.).

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matter related to general usage and behavior.

    .. _`CTRL: A Conditional Transformer Language Model for Controllable Generation`:
        https://www.github.com/salesforce/ctrl

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.CTRLConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

CTRL_INPUTS_DOCSTRING = r"""    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            CTRL is a model with absolute position embeddings so it's usually advised to pad the inputs on
            the right rather than the left.
            Indices can be obtained using :class:`transformers.CTRLTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **past**:
            list of ``torch.FloatTensor`` (one for each layer):
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            A parallel sequence of tokens (can be used to indicate various portions of the inputs).
            The embeddings from these tokens will be summed with the respective token embeddings.
            Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices)
            and are embedded with the same token embedding matrix.
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.n_positions - 1]``.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
            Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""


@add_start_docstrings(
    "The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
    CTRL_START_DOCSTRING,
    CTRL_INPUTS_DOCSTRING,
)
class CTRLModel(CTRLPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **past**: (`optional`, returned when ``config.output_past=True``)
            tuple of ``torch.FloatTensor`` (one for each layer) of shape ``(2, batch_size, num_heads, past_length + sequence_length, embed_size_per_head)``:
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            tuple of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            tuple of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, past_length + sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = CTRLTokenizer.from_pretrained('ctrl')
        model = CTRLModel.from_pretrained('ctrl')
        input_ids = torch.tensor(tokenizer.encode("Links Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """

    def __init__(self, config):
        super(CTRLModel, self).__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.output_past = config.output_past

        self.d_model_size = config.n_embd
        self.num_layers = config.n_layer

        # Fixed sinusoidal table of shape (n_positions, n_embd). This is a plain tensor
        # attribute, not a registered buffer or parameter, so Module.to()/cuda() do not move it.
        # forward() moves it to the input device on demand.
        self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)

        self.w = nn.Embedding(config.vocab_size, config.n_embd)

        self.dropout = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList(
            [
                EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop, config.output_attentions)
                for _ in range(config.n_layer)
            ]
        )
        self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

    def get_input_embeddings(self):
        return self.w

    def set_input_embeddings(self, new_embeddings):
        self.w = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
                heads_to_prune: dict of {layer_num: list of heads to prune in this layer}

        Head pruning is not supported for CTRL. Each ``EncoderLayer`` keeps its attention in
        ``multi_head_attention``; there is no ``attn`` attribute. ``MultiHeadAttention`` has no
        ``prune_heads`` method, and its forward reshapes with a fixed ``d_model_size``.
        An empty request is a no-op. Anything else raises a clear error instead of an
        ``AttributeError``.
        """
        if heads_to_prune:
            raise NotImplementedError("Head pruning is not implemented for CTRL models.")

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
    ):
        """Run the transformer. See the class docstring for inputs and outputs."""
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            # Fold any extra leading dimensions into the batch dimension.
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            # Each entry is torch.stack((key, value)); dim -2 of the key is the cached length.
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # Attention mask.
        if attention_mask is not None:
            # Flatten any extra leading dims but keep the mask's own last dimension. That
            # dimension may also cover the cached `past` positions, not only the current tokens.
            attention_mask = attention_mask.view(-1, attention_mask.size(-1))
            # We create a 4D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                # We can specify head_mask for each layer
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            # switch to float if needed + fp16 compatibility
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
        else:
            head_mask = [None] * self.config.n_layer

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
            token_type_embeds = self.w(token_type_ids)
            token_type_embeds *= np.sqrt(self.d_model_size)
        else:
            token_type_embeds = 0
        position_ids = position_ids.view(-1, input_shape[-1])

        if inputs_embeds is None:
            inputs_embeds = self.w(input_ids)
        seq_len = input_shape[-1]
        # Causal mask over (past + current) positions: 1.0 above the diagonal marks future keys.
        mask = torch.triu(
            torch.ones(seq_len + past_length, seq_len + past_length, device=inputs_embeds.device), 1
        )

        # Out-of-place on purpose. A caller-supplied inputs_embeds must not be rescaled in place:
        # that would mutate the caller's tensor, and it errors on leaf tensors that require grad.
        inputs_embeds = inputs_embeds * np.sqrt(self.d_model_size)

        if self.pos_encoding.device != inputs_embeds.device:
            # Move the (large) table once and keep the moved copy; see __init__.
            self.pos_encoding = self.pos_encoding.to(inputs_embeds.device)
        # Match the embedding dtype so fp16 models do not get promoted to float32 here.
        pos_embeds = self.pos_encoding[position_ids, :].to(dtype=inputs_embeds.dtype)

        hidden_states = inputs_embeds + pos_embeds + token_type_embeds
        hidden_states = self.dropout(hidden_states)

        output_shape = input_shape + (inputs_embeds.size(-1),)
        presents = ()
        all_hidden_states = ()
        all_attentions = []
        for i, (h, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
            outputs = h(
                hidden_states, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
            )
            hidden_states, present = outputs[:2]
            if self.output_past:
                presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.layernorm(hidden_states)
        hidden_states = hidden_states.view(*output_shape)
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_past:
            outputs = outputs + (presents,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # Leave the head dimension free (-1) when restoring the caller's leading batch dims.
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs


@add_start_docstrings(
    """The CTRL Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """,
    CTRL_START_DOCSTRING,
    CTRL_INPUTS_DOCSTRING,
)
class CTRLLMHeadModel(CTRLPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size - 1]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size - 1]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **past**: (`optional`, returned when ``config.output_past=True``)
            tuple of ``torch.FloatTensor`` (one for each layer) of shape ``(2, batch_size, num_heads, past_length + sequence_length, embed_size_per_head)``:
            that contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            tuple of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            tuple of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, past_length + sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import torch
        from transformers import CTRLTokenizer, CTRLLMHeadModel

        tokenizer = CTRLTokenizer.from_pretrained('ctrl')
        model = CTRLLMHeadModel.from_pretrained('ctrl')

        input_ids = torch.tensor(tokenizer.encode("Links Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=input_ids)
        loss, logits = outputs[:2]

    """

    def __init__(self, config):
        super(CTRLLMHeadModel, self).__init__(config)
        self.transformer = CTRLModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build ``forward`` kwargs for one generation step.

        When a truthy ``past`` is supplied, only the last token of ``input_ids`` is fed.
        The earlier tokens are already represented in the cached keys/values.
        """
        if kwargs.get("past"):
            input_ids = input_ids[:, -1].unsqueeze(-1)

        inputs = {"input_ids": input_ids}
        inputs.update(kwargs)
        return inputs

    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        """Run the transformer and the LM head. See the class docstring for inputs and outputs."""
        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]

        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens. CrossEntropyLoss's default ignore_index (-100) skips masked labels.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)