# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model."""


import copy
import math
import os
import warnings
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    logging,
    replace_return_docstrings,
)
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_t5 import T5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
_CHECKPOINT_FOR_DOC = "t5-small"

####################################################
# This list contains the ids of the pretrained
# checkpoints provided with the model
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]
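

# Illustrative only (not part of the original module): any identifier in the list above can be
# passed straight to `from_pretrained`. The helper name below is hypothetical and assumes network
# access or a populated local cache.
def _t5_pretrained_checkpoint_example():
    # `T5ForConditionalGeneration` is defined further down in this file.
    return T5ForConditionalGeneration.from_pretrained("t5-small")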


####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    tf_weights = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        tf_weights[name] = array

    for txt_name in names:
        name = txt_name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        if "_slot_" in name[-1]:
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        pointer = model
        array = tf_weights[txt_name]

        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ["kernel", "scale", "embedding"]:
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "self_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[0]
            elif scope_names[0] == "enc_dec_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[1]
            elif scope_names[0] == "dense_relu_dense":
                pointer = getattr(pointer, "layer")
                pointer = pointer[2]
            elif scope_names[0] == "rms_norm":
                if hasattr(pointer, "layer_norm"):
                    pointer = getattr(pointer, "layer_norm")
                elif hasattr(pointer, "final_layer_norm"):
                    pointer = getattr(pointer, "final_layer_norm")
            elif scope_names[0] == "scale":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            elif scope_names[0] == "decoder" and name[1] == "logits":
                continue
            elif scope_names[0] == "logits":
                pointer = getattr(pointer, "lm_head")
            elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
                pointer = getattr(pointer, f"wi_{scope_names[1]}")
                continue
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if scope_names[0] not in ["kernel", "scale", "embedding"]:
            pointer = getattr(pointer, "weight")
        if scope_names[0] != "embedding":
            logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
            array = np.transpose(array)
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array.astype(np.float32))
        tf_weights.pop(txt_name, None)

    logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
    return model
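

# Illustrative only (not part of the original module): a minimal sketch of how the conversion
# helper above is typically driven. The checkpoint path and the helper name are placeholders of
# ours; the maintained entry point is convert_t5_original_tf_checkpoint_to_pytorch.py.
def _load_tf_weights_example(tf_checkpoint_path="/path/to/tf_checkpoint/model.ckpt"):
    # Build a randomly initialized PyTorch T5 (class defined further down in this file),
    # then overwrite its parameters with the TF variables found in the checkpoint.
    config = T5Config.from_pretrained("t5-small")
    model = T5ForConditionalGeneration(config)
    return load_tf_weights_in_t5(model, config, tf_checkpoint_path)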


####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (itself a subclass of nn.Module)
####################################################
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (`Dict[int, list]`, optional, defaults to None):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
            following number of attention modules:

                - t5-small: 6
                - t5-base: 12
                - t5-large: 24
                - t5-3b: 24
                - t5-11b: 24

    Example:

    ```python
    # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
    model = T5ForConditionalGeneration.from_pretrained("t5-3b")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)
    ```
"""
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example:

    ```python
    # On a 4 GPU machine with t5-3b:
    model = T5ForConditionalGeneration.from_pretrained("t5-3b")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Puts the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
        # half-precision inputs is done in fp32

        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
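

# Illustrative only (not part of the original module): a tiny sanity check of the RMS-style norm
# above. The helper name is ours, and it assumes the pure-Python T5LayerNorm (the optional apex
# FusedRMSNorm substitution below behaves the same way for this check).
def _t5_layer_norm_example():
    norm = T5LayerNorm(8, eps=1e-6)
    hidden = torch.randn(2, 3, 8)
    out = norm(hidden)
    # With the weight still at its initial value of 1, every vector along the last
    # dimension comes out with a root-mean-square of ~1.
    rms = out.pow(2).mean(-1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-4)
    return out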


try:
    from apex.normalization import FusedRMSNorm

    T5LayerNorm = FusedRMSNorm  # noqa

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
except ImportError:
    # using the normal T5LayerNorm
    pass
except Exception:
    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
    pass

ALL_LAYERNORM_LAYERS.append(T5LayerNorm)


class T5DenseActDense(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)
        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedActDense(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)

        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
        # See https://github.com/huggingface/transformers/issues/20287
        # we also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)

        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5LayerFF(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        if config.is_gated_act:
            self.DenseReluDense = T5DenseGatedActDense(config)
        else:
            self.DenseReluDense = T5DenseActDense(config)

        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states
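

# Illustrative only (not part of the original module): a small sketch of how the config selects
# between the plain and the gated feed-forward above. The tiny hyper-parameters are arbitrary;
# "gated-gelu" is the variant used by the t5.1.1 / flan-t5 style checkpoints.
def _t5_layer_ff_example():
    config = T5Config(d_model=16, d_ff=32, feed_forward_proj="gated-gelu", dropout_rate=0.0)
    ff = T5LayerFF(config)  # picks T5DenseGatedActDense because config.is_gated_act is True
    hidden = torch.randn(2, 5, 16)
    return ff(hidden)  # residual block: output keeps the input shape (2, 5, 16)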


class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            if len(past_key_value) != 2:
                raise ValueError(
                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
                )
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                    # the provided `key_value_states` to support prefix tuning
                    # cross-attn
                    # (batch_size, n_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(key_value_states))
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
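

# Illustrative only (not part of the original module): a sketch of the relative position bias
# machinery above. The helper name and the tiny config values are ours.
def _t5_relative_bias_example():
    # With bidirectional=True and the defaults (num_buckets=32, max_distance=128), a relative
    # position of 0 maps to bucket 0, -3 to bucket 3 and +3 to bucket 16 + 3 = 19 (positive
    # offsets are shifted into the upper half of the buckets).
    buckets = T5Attention._relative_position_bucket(torch.tensor([[0, -3, 3]]), bidirectional=True)

    # compute_bias turns those buckets into a learned per-head bias that is added to the raw
    # attention scores; its shape is (1, n_heads, query_length, key_length).
    config = T5Config(d_model=16, d_kv=4, d_ff=32, num_heads=2, num_layers=1, vocab_size=128)
    attn = T5Attention(config, has_relative_attention_bias=True)
    bias = attn.compute_bias(query_length=5, key_length=5)  # (1, 2, 5, 5)
    return buckets, bias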


class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5LayerCrossAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(T5LayerCrossAttention(config))

        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):
        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16:
                clamp_value = torch.where(
                    torch.isinf(hidden_states).any(),
                    torch.finfo(hidden_states.dtype).max - 1000,
                    torch.finfo(hidden_states.dtype).max,
                )
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
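

# Illustrative only (not part of the original module): a sketch of the tuple returned by T5Block
# above for an encoder-side block. The helper name and the tiny config values are ours.
def _t5_block_outputs_example():
    config = T5Config(d_model=16, d_kv=4, d_ff=32, num_heads=2, num_layers=1, dropout_rate=0.0, vocab_size=128)
    block = T5Block(config, has_relative_attention_bias=True)
    hidden = torch.randn(1, 5, 16)
    # With use_cache=False and output_attentions=False the encoder block returns
    # (hidden_states, position_bias); a decoder block would additionally carry the present
    # key/value states and the cross-attention position bias.
    hidden_states, position_bias = block(hidden, use_cache=False)
    return hidden_states.shape, position_bias.shape  # (1, 5, 16) and (1, 2, 5, 5)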


class T5ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config: T5Config):
        super().__init__()
        self.dense = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.out_proj = nn.Linear(config.d_model, config.num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states
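

# Illustrative only (not part of the original module): the head above maps a pooled
# (batch_size, d_model) representation to (batch_size, num_labels); num_labels defaults to 2 on
# the config. The helper name and config values are ours.
def _t5_classification_head_example():
    config = T5Config(d_model=16, classifier_dropout=0.0)
    head = T5ClassificationHead(config)
    pooled = torch.randn(4, 16)
    return head(pooled)  # (4, 2)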


class T5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["T5Block"]
    _keep_in_fp32_modules = ["wo"]

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(
            module,
            (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering),
        ):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "qa_outputs"):
                module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
                module.qa_outputs.bias.data.zero_()
        elif isinstance(module, T5ClassificationHead):
            module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.dense, "bias") and module.dense.bias is not None:
                module.dense.bias.data.zero_()
            module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
                module.out_proj.bias.data.zero_()
        elif isinstance(module, T5DenseActDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5DenseGatedActDense):
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5Attention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
                "See T5 docs for more information."
            )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
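

# Illustrative only (not part of the original module): a worked example of the label shifting done
# by `_shift_right` above. The helper name and the tiny config values are ours.
def _shift_right_example():
    # T5 uses pad_token_id (0) as the decoder start token. Labels [[6, 7, 8, -100]] become the
    # decoder inputs [[0, 6, 7, 8]]; any remaining -100 ignore-index entries are replaced by the
    # pad token id.
    config = T5Config(
        d_model=16, d_kv=4, d_ff=32, num_layers=1, num_heads=2, vocab_size=128, decoder_start_token_id=0
    )
    model = T5ForConditionalGeneration(config)  # defined further down in this file
    labels = torch.tensor([[6, 7, 8, -100]])
    return model._shift_right(labels)  # tensor([[0, 6, 7, 8]])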


class T5Stack(T5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        warnings.warn(
            "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
            " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
            " 'block.1': 1, ...}",
            FutureWarning,
        )
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)

        # Set embed_tokens to first layer
        self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        self.embed_tokens = self.embed_tokens.to("cpu")
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError("You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

        if use_cache is True:
            if not self.is_decoder:
                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
1041
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
thomwolf's avatar
thomwolf committed
1042

1043
1044
1045
1046
1047
1048
1049
        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
1050
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
thomwolf's avatar
thomwolf committed
1051
1052
        else:
            encoder_extended_attention_mask = None
thomwolf's avatar
thomwolf committed
1053

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

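        # Iterate over the T5 blocks. When caching is enabled, `past_key_values` holds one tuple per
        # layer with the cached self-attention (and, for the decoder, cross-attention) key/value states.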
        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.forward,
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value is always None with gradient checkpointing
                    use_cache,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

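            # layer_outputs[0] is the updated hidden states; layer_outputs[1] is the present key/value
            # tuple (set to `None` above when caching is disabled so the indexing below stays stable).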
            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them
            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

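        # End of the block loop: apply the final T5 layer norm (an RMS-style norm without bias or mean
        # subtraction) and dropout to the output of the last block.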
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

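        # Assemble the outputs. For the plain-tuple return path, entries that are `None` (e.g. the cache
        # when `use_cache=False`) are dropped.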
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )


T5_START_DOCSTRING = r"""

    The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
    Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
    Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
    text-to-text denoising generative setting.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`T5Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
                `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.

        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

T5_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""


@add_start_docstrings(
    "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [
        "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

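        # The decoder gets its own config copy: `is_decoder=True` enables causal masking and
        # cross-attention inside `T5Stack`, and it may use a different number of layers
        # (`config.num_decoder_layers`).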
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        warnings.warn(
            "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
            " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
            " 0, 'encoder.block.1': 1, ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, T5Model

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
        >>> model = T5Model.from_pretrained("t5-small")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

        >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
        >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
        >>> decoder_input_ids = model._shift_right(decoder_input_ids)

        >>> # forward pass
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [
        "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        warnings.warn(
            "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
            " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
            " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
            " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, T5ForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
        >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

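        # Run the decoder over the (possibly shifted) target tokens; it attends to the encoder output
        # through cross-attention via `encoder_hidden_states`.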
        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

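        # Project the decoder hidden states onto the vocabulary to obtain the language-modeling logits.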
        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        decoder_attention_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

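        # `input_ids` here are the decoder tokens generated so far (truncated above when a cache is
        # passed); the generation loop forwards everything below straight into `forward`.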
        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past_key_values, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past_key_values

        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at the first position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
                raise ValueError(
                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
                )
            if len(reordered_layer_past_states) != len(layer_past_states):
                raise ValueError(
                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
                )

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past


@add_start_docstrings(
    "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5EncoderModel(T5PreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight"]
    _keys_to_ignore_on_load_unexpected = [r"decoder"]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        warnings.warn(
            "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
            " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
            " 'block.1': 1, ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, T5EncoderModel

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
        >>> model = T5EncoderModel.from_pretrained("t5-small")
        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

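        # Encoder-only forward pass: delegate directly to the underlying T5Stack and return its output.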
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs


@add_start_docstrings(
    """
    T5 model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE
    tasks.
    """,
    T5_START_DOCSTRING,
)
class T5ForSequenceClassification(T5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.transformer = T5Model(config)
        self.classification_head = T5ClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

        self.model_parallel = False

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Returns:
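
        Example (a minimal usage sketch; the `t5-small` checkpoint and `num_labels=2` are illustrative choices, and
        the classification head of a base checkpoint stays randomly initialized until fine-tuned):

        ```python
        >>> from transformers import AutoTokenizer, T5ForSequenceClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
        >>> model = T5ForSequenceClassification.from_pretrained("t5-small", num_labels=2)  # illustrative checkpoint/labels
        >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits  # shape: (batch_size, num_labels)
        ```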
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        if input_ids is None and inputs_embeds is not None:
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
            )

        # Copied from models.bart.modeling_bart.BartModel.forward
        #   different to other models, T5 automatically creates decoder_input_ids from
        #   input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = self._shift_right(input_ids)

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)

        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
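        # Pool: gather the hidden states at the <eos> positions and use the last one as the sentence representation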
        batch_size, _, hidden_size = sequence_output.shape
        sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
        logits = self.classification_head(sentence_representation)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
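            # If the problem type was not set in the config, infer it from `num_labels` and the label dtype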
            if self.config.problem_type is None:
                if self.config.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.config.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )


@add_start_docstrings(
    """
    T5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
    on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    T5_START_DOCSTRING,
)
class T5ForQuestionAnswering(T5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.num_labels = config.num_labels
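        # Span-prediction head: projects each token's hidden state to start/end logits (num_labels is expected to be 2)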
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

        self.model_parallel = False

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
            are not taken into account for computing the loss.
        Returns:
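
        Example (a minimal usage sketch; the `t5-small` checkpoint and the question/context strings are illustrative,
        and the span head of a base checkpoint stays randomly initialized until fine-tuned on an extractive QA
        dataset):

        ```python
        >>> from transformers import AutoTokenizer, T5ForQuestionAnswering

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
        >>> model = T5ForQuestionAnswering.from_pretrained("t5-small")  # illustrative checkpoint
        >>> question, context = "Who owns a dog?", "Studies have shown that Alice, who owns a dog, is happier."
        >>> inputs = tokenizer(question, context, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> start_logits, end_logits = outputs.start_logits, outputs.end_logits
        ```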
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if start_positions is not None and end_positions is not None:
            use_cache = False

        # Copied from models.bart.modeling_bart.BartModel.forward
        #   different to other models, T5 automatically creates decoder_input_ids from
        #   input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = self._shift_right(input_ids)

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=None,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the start/end positions may carry an extra dimension; squeeze it
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1).to(start_logits.device)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1).to(end_logits.device)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
            return ((total_loss,) + output) if total_loss is not None else output

        return Seq2SeqQuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )