"alphafold/model/prng_test.py" did not exist on "2f0d89e765051fc9e26fb4c52e5ad91bbb0e7e0b"
modeling_t5.py 60 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
thomwolf's avatar
thomwolf committed
2
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
thomwolf's avatar
thomwolf committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model. """


Aymeric Augustin's avatar
Aymeric Augustin committed
18
import copy
thomwolf's avatar
thomwolf committed
19
20
import math
import os
Sylvain Gugger's avatar
Sylvain Gugger committed
21
import warnings
thomwolf's avatar
thomwolf committed
22
23

import torch
thomwolf's avatar
thomwolf committed
24
import torch.nn.functional as F
Aymeric Augustin's avatar
Aymeric Augustin committed
25
from torch import nn
26
from torch.nn import CrossEntropyLoss
thomwolf's avatar
thomwolf committed
27

Patrick von Platen's avatar
Patrick von Platen committed
28
from ...activations import ACT2FN
Sylvain Gugger's avatar
Sylvain Gugger committed
29
from ...file_utils import (
30
31
32
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
33
    add_start_docstrings_to_model_forward,
34
35
    replace_return_docstrings,
)
Sylvain Gugger's avatar
Sylvain Gugger committed
36
from ...modeling_outputs import (
37
38
39
40
41
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
Sylvain Gugger's avatar
Sylvain Gugger committed
42
43
44
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_t5 import T5Config
Aymeric Augustin's avatar
Aymeric Augustin committed
45

thomwolf's avatar
thomwolf committed
46

Lysandre Debut's avatar
Lysandre Debut committed
47
logger = logging.get_logger(__name__)
thomwolf's avatar
thomwolf committed
48

49
_CONFIG_FOR_DOC = "T5Config"
50
51
_TOKENIZER_FOR_DOC = "T5Tokenizer"

thomwolf's avatar
thomwolf committed
52
####################################################
53
# This dict contains shortcut names and associated url
thomwolf's avatar
thomwolf committed
54
55
# for the pretrained weights provided with the models
####################################################
56
57
58
59
60
61
62
63
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]
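
# Illustrative note (not part of the original file): any of the shortcut names above can be
# passed to `from_pretrained`; a minimal sketch (requires network access to the model hub):
#
#     model = T5Model.from_pretrained("t5-small")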


####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    tf_weights = {}
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        tf_weights[name] = array

    for txt_name in names:
        name = txt_name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info("Skipping {}".format("/".join(name)))
            tf_weights.pop(txt_name, None)
            continue
        if "_slot_" in name[-1]:
            logger.info("Skipping {}".format("/".join(name)))
            tf_weights.pop(txt_name, None)
            continue
        pointer = model
        array = tf_weights[txt_name]

        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ["kernel", "scale", "embedding"]:
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "self_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[0]
            elif scope_names[0] == "enc_dec_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[1]
            elif scope_names[0] == "dense_relu_dense":
                pointer = getattr(pointer, "layer")
                pointer = pointer[2]
            elif scope_names[0] == "rms_norm":
                if hasattr(pointer, "layer_norm"):
                    pointer = getattr(pointer, "layer_norm")
                elif hasattr(pointer, "final_layer_norm"):
                    pointer = getattr(pointer, "final_layer_norm")
            elif scope_names[0] == "scale":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            elif scope_names[0] == "decoder" and name[1] == "logits":
                continue
            elif scope_names[0] == "logits":
                pointer = getattr(pointer, "lm_head")
            elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
                pointer = getattr(pointer, f"wi_{scope_names[1]}")
                continue
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if scope_names[0] not in ["kernel", "scale", "embedding"]:
            pointer = getattr(pointer, "weight")
        if scope_names[0] != "embedding":
            logger.info("Transposing numpy weight of shape {} for {}".format(array.shape, name))
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array.astype(np.float32))
        tf_weights.pop(txt_name, None)

    logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys())))
    return model
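
# Illustrative usage sketch (not part of the original file): converting an original T5
# TensorFlow checkpoint into the PyTorch model. The checkpoint path below is a hypothetical
# placeholder.
#
#     config = T5Config.from_pretrained("t5-small")
#     model = T5ForConditionalGeneration(config)
#     model = load_tf_weights_in_t5(model, config, "/path/to/tf_checkpoint/model.ckpt")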


####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (itself a sub-class of torch.nn.Module)
####################################################


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # layer norm should always be calculated in float32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into float16 if necessary
        if self.weight.dtype == torch.float16:
            hidden_states = hidden_states.to(torch.float16)
        return self.weight * hidden_states
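
# Illustrative sketch (not part of the original file): T5LayerNorm is an RMS-style norm,
# i.e. y = w * x / sqrt(mean(x**2, dim=-1) + eps), with no mean subtraction and no bias.
#
#     layer_norm = T5LayerNorm(hidden_size=4)
#     x = torch.randn(2, 3, 4)   # (batch, seq_len, hidden)
#     y = layer_norm(x)          # same shape, rescaled per position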


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = F.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states
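
# Illustrative note (not part of the original file): the gated-GELU feed-forward used by
# T5 v1.1-style checkpoints computes
#     FF(x) = dropout(GELU(x @ W_i0) * (x @ W_i1)) @ W_o
# i.e. two parallel input projections whose element-wise product gates the hidden
# activation before the output projection.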


class T5LayerFF(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.feed_forward_proj == "relu":
            self.DenseReluDense = T5DenseReluDense(config)
        elif config.feed_forward_proj == "gated-gelu":
            self.DenseReluDense = T5DenseGatedGeluDense(config)
        else:
            raise ValueError(
                f"{config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
            )

        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states
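
# Illustrative note (not part of the original file): T5 uses "pre-norm" residual blocks.
# The layer norm is applied to the *input* of the sub-layer and the sub-layer output is
# added back to the un-normalized input, i.e.
#     x = x + dropout(FF(layer_norm(x)))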


class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias

        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_postion_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
        return relative_buckets
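
    # Illustrative worked example (not part of the original file), using the defaults
    # bidirectional=True, num_buckets=32, max_distance=128: the 32 buckets are split into
    # 16 for non-positive and 16 for positive relative positions; within each half,
    # distances 0..7 get their own bucket, distances up to 127 are binned logarithmically,
    # and anything >= max_distance collapses into the last bucket.
    #
    #     rel = torch.tensor([-3, 0, 3, 200])
    #     T5Attention._relative_position_bucket(rel)   # -> tensor([ 3,  0, 19, 31])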

    def compute_bias(self, query_length, key_length):
        """ Compute binned relative position bias """
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
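
    # Illustrative note (not part of the original file): compute_bias returns a tensor of
    # shape (1, n_heads, query_length, key_length) holding one learned scalar per
    # (head, relative-position bucket). It is added to the raw attention scores before the
    # softmax, which is how T5 encodes position without positional embeddings on the inputs.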

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
                len(past_key_value)
            )
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """ projection """
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """ reshape """
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """ projects hidden states correctly to key/query states """
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
            else:
                position_bias = self.compute_bias(real_seq_length, key_length)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        scores += position_bias
        attn_weights = F.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = F.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
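
    # Illustrative note (not part of the original file): T5 attention is *unscaled*
    # dot-product attention (no 1/sqrt(d_k) factor); the Mesh TensorFlow initialization of
    # `q` noted above is what compensates for the missing scaling. The returned tuple is
    #     (attn_output, present_key_value_state, position_bias[, attn_weights])
    # where the last entry is only present when output_attentions=True.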


class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5LayerCrossAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(T5LayerCrossAttention(config))

        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):

        if past_key_value is not None:
            assert self.is_decoder, "Only decoder can use `past_key_values`"
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
                expected_num_past_key_values,
                " 2 (past / key) for cross attention." if expected_num_past_key_values == 4 else "",
                len(past_key_value),
            )
            assert len(past_key_value) == expected_num_past_key_values, error_message

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                head_mask=head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)
        outputs = (hidden_states,)

        outputs = outputs + (present_key_value_state,) + attention_outputs
        # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias),
        # (cross-attention weights), (cross-attention position bias)
        return outputs


class T5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """ Initialize the weights """
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, (T5Model, T5ForConditionalGeneration)):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, T5DenseReluDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5DenseGatedGeluDense):
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5Attention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert (
            decoder_start_token_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids
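
    # Illustrative worked example (not part of the original file): with pad_token_id = 0
    # and decoder_start_token_id = 0 (T5's defaults), labels [[37, 32, 1, -100]] become
    # decoder inputs [[0, 37, 32, 1]] -- every token is shifted one position to the right,
    # the start token is prepended, and any remaining -100 (ignored label) would be
    # replaced by the pad token.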


class T5Stack(T5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embed_tokens

    def get_output_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

        if use_cache is True:
            assert self.is_decoder, ":obj:`use_cache` can only be set to `True` if {} is used as a decoder".format(
                self
            )

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)

        if self.is_decoder and encoder_attention_mask is not None:
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                head_mask=head_mask[i],
                past_key_value=past_key_value,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention weights), (self-attention position bias),
            # (cross-attention weights), (cross-attention position bias)
            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
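
# Rough sketch (not part of the original file): T5Stack is the shared encoder/decoder
# trunk. Wired up the same way T5Model.__init__ does below, it can be run standalone as an
# encoder (randomly initialized weights; `from_pretrained` is the usual entry point):
#
#     config = T5Config.from_pretrained("t5-small")
#     encoder_config = copy.deepcopy(config)
#     encoder_config.use_cache = False
#     embed = nn.Embedding(config.vocab_size, config.d_model)
#     encoder = T5Stack(encoder_config, embed)
#     out = encoder(input_ids=torch.tensor([[37, 32, 1]]))
#     out.last_hidden_state.shape   # (1, 3, config.d_model)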


T5_START_DOCSTRING = r"""

    The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
    <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
    Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a text-to-text
    denoising generative setting.

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
    general usage and behavior.

    Parameters:
        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

T5_INPUTS_DOCSTRING = r"""
Patrick von Platen's avatar
Patrick von Platen committed
910
911
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Sylvain Gugger's avatar
Sylvain Gugger committed
912
913
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.
Sylvain Gugger's avatar
Sylvain Gugger committed
914

Sylvain Gugger's avatar
Sylvain Gugger committed
915
916
917
            Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            detail.
Sylvain Gugger's avatar
Sylvain Gugger committed
918

Sylvain Gugger's avatar
Sylvain Gugger committed
919
920
            To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
            <./t5.html#training>`__.
921
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Sylvain Gugger's avatar
Sylvain Gugger committed
922
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
Sylvain Gugger's avatar
Sylvain Gugger committed
923
924

            - 1 for tokens that are **not masked**,
925
            - 0 for tokens that are **masked**.
Sylvain Gugger's avatar
Sylvain Gugger committed
926
927

            `What are attention masks? <../glossary.html#attention-mask>`__
928
        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Sylvain Gugger's avatar
Sylvain Gugger committed
929
            Provide for sequence to sequence training. T5 uses the :obj:`pad_token_id` as the starting token for
Sylvain Gugger's avatar
Sylvain Gugger committed
930
931
            :obj:`decoder_input_ids` generation. If :obj:`past_key_values` is used, optionally only the last
            :obj:`decoder_input_ids` have to be input (see :obj:`past_key_values`).
Sylvain Gugger's avatar
Sylvain Gugger committed
932

Sylvain Gugger's avatar
Sylvain Gugger committed
933
934
935
            To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
            <./t5.html#training>`__. If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset,
            :obj:`decoder_input_ids` takes the value of :obj:`input_ids`.
936
        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
Sylvain Gugger's avatar
Sylvain Gugger committed
937
938
            Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
            also be used by default.
939
        encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Sylvain Gugger's avatar
Sylvain Gugger committed
940
941
942
943
            Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
            `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
            sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
944
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Sylvain Gugger's avatar
Sylvain Gugger committed
945
946
947
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
948
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
Sylvain Gugger's avatar
Sylvain Gugger committed
949
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
950
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Sylvain Gugger's avatar
Sylvain Gugger committed
951
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
952
953
954
955

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

956
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Patrick von Platen's avatar
Patrick von Platen committed
957
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
Sylvain Gugger's avatar
Sylvain Gugger committed
958
959
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
960
        decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
Sylvain Gugger's avatar
Sylvain Gugger committed
961
            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
Sylvain Gugger's avatar
Sylvain Gugger committed
962
963
964
            representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
            have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
            :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
Sylvain Gugger's avatar
Sylvain Gugger committed
965

Sylvain Gugger's avatar
Sylvain Gugger committed
966
967
            If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
            takes the value of :obj:`inputs_embeds`.
Sylvain Gugger's avatar
Sylvain Gugger committed
968

969
970
971
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
Sylvain Gugger's avatar
Sylvain Gugger committed
972

973
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
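    # Parameter name patterns that are allowed to be missing when loading a pretrained
    # checkpoint (e.g. weights that are tied to `shared` or recreated at load time).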
    authorized_missing_keys = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

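        # Encoder and decoder are two T5Stack instances sharing the token embedding matrix above;
        # each gets its own config copy so that encoder-side settings (no caching, not a decoder)
        # and decoder-side settings (is_decoder, num_decoder_layers) stay separate.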
        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.init_weights()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in
        this layer}. See the base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
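            # In T5Stack the transformer layers live in `self.block`; the first sub-layer of each
            # block wraps the self-attention module as `SelfAttention`.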
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        head_mask=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        Returns:

        Example::

            >>> from transformers import T5Tokenizer, T5Model

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5Model.from_pretrained('t5-small')

            >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
            >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

            >>> last_hidden_states = outputs.last_hidden_state
        """
        if "decoder_past_key_value_states" in kwargs:
            warnings.warn(
                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_value_states")
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
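        # If the caller passed precomputed encoder outputs as a plain tuple but requested
        # dict-style outputs, wrap them in a BaseModelOutput so attribute access works below.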
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

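        # When not returning a dict, simply concatenate the decoder and encoder output tuples.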
        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
    authorized_missing_keys = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

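        # Final projection from d_model to the vocabulary. When config.tie_word_embeddings is True,
        # this weight is tied to `self.shared` by the base class.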
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.init_weights()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        head_mask=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in
            :obj:`[-100, 0, ..., config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked); the loss
            is only computed for labels in ``[0, ..., config.vocab_size - 1]``.
        kwargs (:obj:`Dict[str, Any]`, `optional`, defaults to :obj:`{}`):
            Used to hide legacy arguments that have been deprecated.

        Returns:

        Examples::

            >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5ForConditionalGeneration.from_pretrained('t5-small')

            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
            >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits

            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model.generate(input_ids)
        """

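        # Remap deprecated keyword arguments (`lm_labels`, `decoder_past_key_value_states`,
        # `decoder_past_key_values`) onto their current names before rejecting anything else.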
        if "lm_labels" in kwargs:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("lm_labels")
        if "decoder_past_key_value_states" in kwargs:
            warnings.warn(
                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_value_states")
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs into embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
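            # With tied embeddings the same matrix is used for input lookup and output projection,
            # so the hidden states are scaled by 1/sqrt(d_model) before the projection to match
            # the Mesh TensorFlow implementation.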
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
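            # Token-level cross-entropy over the vocabulary; positions labelled -100 (e.g. padding)
            # are ignored by the loss.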
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

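        # When not returning a dict, prepend the loss (if computed) to the logits, remaining
        # decoder outputs and encoder outputs.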
        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
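        # Called by `generate()` at every decoding step: `input_ids` holds the decoder tokens
        # produced so far, while `encoder_outputs` were computed once and are reused unchanged.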

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _reorder_cache(self, past, beam_idx):
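        # Beam search reorders hypotheses at every step; the cached key / value states of each
        # decoder layer must be reordered along the batch dimension with `beam_idx` to match.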
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # get the correct batch idx from the layer past batch dim
            # the batch dim of each cached key / value state is its first dimension
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past