# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model."""


import copy
import math
import os
import warnings
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    logging,
    replace_return_docstrings,
)
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_t5 import T5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
_CHECKPOINT_FOR_DOC = "t5-small"

####################################################
# This dict contains ids and associated urls
# for the pretrained weights provided with the models
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]


####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    tf_weights = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        tf_weights[name] = array

    for txt_name in names:
        name = txt_name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        if "_slot_" in name[-1]:
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        pointer = model
        array = tf_weights[txt_name]

        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ["kernel", "scale", "embedding"]:
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "self_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[0]
            elif scope_names[0] == "enc_dec_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[1]
            elif scope_names[0] == "dense_relu_dense":
                pointer = getattr(pointer, "layer")
                pointer = pointer[2]
            elif scope_names[0] == "rms_norm":
                if hasattr(pointer, "layer_norm"):
                    pointer = getattr(pointer, "layer_norm")
                elif hasattr(pointer, "final_layer_norm"):
                    pointer = getattr(pointer, "final_layer_norm")
            elif scope_names[0] == "scale":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            elif scope_names[0] == "decoder" and name[1] == "logits":
                continue
            elif scope_names[0] == "logits":
                pointer = getattr(pointer, "lm_head")
            elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
                pointer = getattr(pointer, f"wi_{scope_names[1]}")
                continue
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if scope_names[0] not in ["kernel", "scale", "embedding"]:
            pointer = getattr(pointer, "weight")
        if scope_names[0] != "embedding":
            logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array.astype(np.float32))
        tf_weights.pop(txt_name, None)

    logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
    return model


####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (itself a sub-class of nn.Module)
####################################################
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (`Dict[int, list]`, optional, defaults to None):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
            following number of attention modules:

                - t5-small: 6
                - t5-base: 12
                - t5-large: 24
                - t5-3b: 24
                - t5-11b: 24

    Example:

    ```python
    # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
    model = T5ForConditionalGeneration.from_pretrained("t5-3b")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)
    ```
"""
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example:

    ```python
    # On a 4 GPU machine with t5-3b:
    model = T5ForConditionalGeneration.from_pretrained("t5-3b")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Puts the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):

        # T5 uses a layer_norm which only scales and doesn't shift, also known as Root Mean Square Layer
        # Normalization (https://arxiv.org/abs/1910.07467), so the variance is calculated without the mean
        # and there is no bias. Additionally, we want to make sure that the accumulation for half-precision
        # inputs is done in fp32.

        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
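
# Illustrative note (added comment, not part of the original module): T5LayerNorm is an RMSNorm, i.e. for
# each hidden vector x it returns weight * x / sqrt(mean(x**2) + eps), with the mean of squares accumulated
# in fp32 even for half-precision inputs. A minimal usage sketch with hypothetical shapes:
#
#   norm = T5LayerNorm(hidden_size=512, eps=1e-6)
#   x = torch.randn(2, 7, 512)
#   y = norm(x)  # same shape as x; each 512-dim vector is rescaled to unit RMS, then weighted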


try:
    from apex.normalization import FusedRMSNorm

    T5LayerNorm = FusedRMSNorm  # noqa

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
except ImportError:
    # using the normal T5LayerNorm
    pass
except Exception:
    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
    pass

ALL_LAYERNORM_LAYERS.append(T5LayerNorm)


class T5DenseActDense(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedActDense(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states
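

# Illustration (added comment, not in the original file): this gated variant computes
# wo(act(wi_0(x)) * wi_1(x)), a GLU-style feed-forward in which one projection is passed through the
# activation and gates the other; it is selected when the config uses a gated activation such as
# feed_forward_proj="gated-gelu" (T5 v1.1 / Flan-T5 style checkpoints).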


class T5LayerFF(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        if config.is_gated_act:
            self.DenseReluDense = T5DenseGatedActDense(config)
        else:
            self.DenseReluDense = T5DenseActDense(config)

        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states
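

# Note added for clarity (not in the original file): this is a pre-norm residual block,
# hidden_states + dropout(FF(layer_norm(hidden_states))); the layer norm is applied to the branch
# input rather than to the residual sum.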


class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
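
    # Worked illustration (added comment, not in the original file): with the defaults above
    # (bidirectional=True, num_buckets=32, max_distance=128), half of the buckets encode the sign of the
    # offset and distances 0..7 each keep their own bucket, e.g.
    #   T5Attention._relative_position_bucket(torch.tensor([0, 1, -1, 5]))  ->  tensor([ 0, 17,  1, 21])
    # while larger distances share logarithmically spaced buckets up to max_distance.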

    def compute_bias(self, query_length, key_length, device=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)

            # if key and value states are already calculated,
            # we only want the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        scores += position_bias
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
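
    # Added note (not in the original file): the returned tuple is laid out as
    # (attn_output, present_key_value_state, position_bias) plus attn_weights when output_attentions=True;
    # position_bias is passed back so that later layers can reuse the bias computed by the first layer.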


class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5LayerCrossAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(T5LayerCrossAttention(config))

        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):

        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)


class T5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["T5Block"]

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, T5DenseActDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5DenseGatedActDense):
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5Attention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (T5Attention, T5Stack)):
            module.gradient_checkpointing = value

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert decoder_start_token_id is not None, (
            "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
            " See T5 docs for more information"
        )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids
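
    # Worked example (added comment, not in the original file; token ids are hypothetical): with the usual
    # T5 settings decoder_start_token_id=0 and pad_token_id=0, labels [[8774, 117, 1]] are shifted to
    # decoder_input_ids [[0, 8774, 117]], and any -100 positions in the labels are replaced by pad_token_id.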


class T5Stack(T5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)

        # Set embed_tokens to first layer
        self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        self.embed_tokens = self.embed_tokens.to("cpu")
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

        if use_cache is True:
            assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
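        # Added note (not in the original file): `get_extended_attention_mask`, inherited from PreTrainedModel,
        # converts the 0/1 padding mask into an additive mask broadcastable to
        # (batch_size, num_heads, seq_length, key_length): 0.0 where attention is allowed and a large negative
        # value where it is masked; for a decoder it also folds in the causal (lower-triangular) mask.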

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return tuple(module(*inputs, use_cache, output_attentions))

                    return custom_forward

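                # torch.utils.checkpoint recomputes this layer's forward pass during backward instead of
                # storing its intermediate activations, trading extra compute for lower peak memory.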
                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value is always None with gradient checkpointing
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )


T5_START_DOCSTRING = r"""

    The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
    Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
    Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It is an encoder-decoder transformer pre-trained in a
    text-to-text denoising generative setting.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`T5Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To learn more about how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To learn more about how to prepare `decoder_input_ids` for pretraining, take a look at [T5
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.

        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

T5_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            To learn more about how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""

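# Illustrative sketch (not executed): callers who want to silence this warning can build an explicit
# decoder head mask instead of relying on `head_mask` being copied, e.g.
#
#     decoder_head_mask = torch.ones(model.config.num_decoder_layers, model.config.num_heads)
#     outputs = model(input_ids, decoder_input_ids=decoder_input_ids, head_mask=head_mask,
#                     decoder_head_mask=decoder_head_mask)
#
# `model`, `input_ids`, `decoder_input_ids` and `head_mask` above are placeholders for the caller's own objects.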

@add_start_docstrings(
    "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder.embed_tokens.weight",
        r"decoder.embed_tokens.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

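        # Each stack gets its own deep-copied config: the encoder is not a decoder, never caches
        # key/value states, and neither copy is flagged as a full encoder-decoder model on its own.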
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

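    # Illustrative sketch of a manual `device_map` for `parallelize()` (assuming two visible GPUs and a
    # 6-block encoder such as t5-small); keys are device ids, values are the block indices they host:
    #
    #     model.parallelize({0: [0, 1, 2], 1: [3, 4, 5]})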
    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
        class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            # T5Stack stores its transformer layers in `self.block`; each block's self-attention
            # module lives at `layer[0].SelfAttention`.
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
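    # Illustrative usage: `model.prune_heads({0: [0, 2], 3: [1]})`; the base class helper
    # `PreTrainedModel.prune_heads` routes such a dict to this method.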

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import T5Tokenizer, T5Model

        >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
        >>> model = T5Model.from_pretrained("t5-small")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

        >>> # forward pass
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder.embed_tokens.weight",
        r"decoder.embed_tokens.weight",
        r"lm_head.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`.

        Returns:

        Examples:

        ```python
        >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

        >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
        >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
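            # `_shift_right` prepends `config.decoder_start_token_id`, drops the last label token and
            # replaces any -100 entries with the pad token id.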
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
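            # label positions set to -100 (e.g. padding) are excluded from the cross-entropy loss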
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):

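        # `generate()` calls this hook before every decoding step; once `past` holds cached key/value
        # states, only the most recently generated decoder token has to be fed back in.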
        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past, beam_idx):
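        # Called during beam search: `beam_idx` gives, for each kept beam, the index of the beam it was
        # expanded from, and the cached key/value states are gathered to match that new beam order.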
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past


@add_start_docstrings(
    "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5EncoderModel(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder.embed_tokens.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            # T5Stack stores its transformer layers in `self.block`; each block's self-attention
            # module lives at `layer[0].SelfAttention`.
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import T5Tokenizer, T5EncoderModel

        >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
        >>> model = T5EncoderModel.from_pretrained("t5-small")
        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs