# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model."""


import copy
import math
import os
import warnings
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    logging,
    replace_return_docstrings,
)
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_t5 import T5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
_CHECKPOINT_FOR_DOC = "t5-small"

####################################################
# This list contains the ids of the pretrained
# checkpoints provided with the models
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]


####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    tf_weights = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        tf_weights[name] = array

    for txt_name in names:
        name = txt_name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        if "_slot_" in name[-1]:
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        pointer = model
        array = tf_weights[txt_name]

        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ["kernel", "scale", "embedding"]:
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "self_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[0]
            elif scope_names[0] == "enc_dec_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[1]
            elif scope_names[0] == "dense_relu_dense":
                pointer = getattr(pointer, "layer")
                pointer = pointer[2]
            elif scope_names[0] == "rms_norm":
                if hasattr(pointer, "layer_norm"):
                    pointer = getattr(pointer, "layer_norm")
                elif hasattr(pointer, "final_layer_norm"):
                    pointer = getattr(pointer, "final_layer_norm")
            elif scope_names[0] == "scale":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            elif scope_names[0] == "decoder" and name[1] == "logits":
                continue
            elif scope_names[0] == "logits":
                pointer = getattr(pointer, "lm_head")
            elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
                pointer = getattr(pointer, f"wi_{scope_names[1]}")
                continue
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if scope_names[0] not in ["kernel", "scale", "embedding"]:
            pointer = getattr(pointer, "weight")
        if scope_names[0] != "embedding":
            logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array.astype(np.float32))
        tf_weights.pop(txt_name, None)

    logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
    return model
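
# Example usage (illustrative sketch; the checkpoint path below is a placeholder):
#
#     config = T5Config.from_pretrained("t5-small")
#     model = T5ForConditionalGeneration(config)
#     load_tf_weights_in_t5(model, config, "/path/to/tf_checkpoint")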


####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (itself a sub-class of nn.Module)
####################################################
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (`Dict[int, list]`, optional, defaults to None):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
            following number of attention modules:

                - t5-small: 6
                - t5-base: 12
                - t5-large: 24
                - t5-3b: 24
                - t5-11b: 24

    Example:

    ```python
    # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
    model = T5ForConditionalGeneration.from_pretrained("t5-3b")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)
    ```
"""
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example:

    ```python
    # On a 4 GPU machine with t5-3b:
    model = T5ForConditionalGeneration.from_pretrained("t5-3b")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Puts the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):

        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization (https://arxiv.org/abs/1910.07467). Thus the variance is calculated
        # without the mean and there is no bias. Additionally, we want to make sure that the
        # accumulation for half-precision inputs is done in fp32.
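        # In short: y = weight * x / sqrt(mean(x**2, over the last dim) + eps)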

        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


try:
    from apex.normalization import FusedRMSNorm

    T5LayerNorm = FusedRMSNorm  # noqa

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
except ImportError:
    # using the normal T5LayerNorm
    pass
except Exception:
    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
    pass


class T5DenseReluDense(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = nn.functional.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5LayerFF(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        if config.feed_forward_proj == "relu":
            self.DenseReluDense = T5DenseReluDense(config)
        elif config.feed_forward_proj == "gated-gelu":
            self.DenseReluDense = T5DenseGatedGeluDense(config)
        else:
            raise ValueError(
                f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
            )

        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
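        # Pre-LayerNorm residual block: normalize, apply the feed-forward sub-layer, then add the
        # dropped-out result back onto the original hidden states.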
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states


class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
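        # (the usual 1/sqrt(d_k) factor is folded into the initialization of `q` in `_init_weights`,
        # so the attention scores are not rescaled in `forward`)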
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
390
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
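    # Quick sanity check with the defaults (bidirectional=True, num_buckets=32, max_distance=128),
    # values worked out by hand from the code above:
    #
    #     T5Attention._relative_position_bucket(torch.tensor([-3, 0, 2, 200]))  ->  tensor([3, 0, 18, 31])
    #
    # i.e. small offsets get their own bucket, past and future offsets use disjoint halves of the
    # bucket range, and large offsets share logarithmically sized buckets.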

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias"""
        context_position = torch.arange(
            query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
        )[:, None]
        memory_position = torch.arange(
            key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
        )[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length)

            # if keys and values are already calculated
            # we only want the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        scores += position_bias
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs


class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5LayerCrossAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(T5LayerCrossAttention(config))

        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):

        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
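        # (torch.finfo(torch.float16).max is 65504, so this clamps activations to roughly +/-64504)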
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)


class T5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, T5DenseReluDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5DenseGatedGeluDense):
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5Attention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (T5Attention, T5Stack)):
            module.gradient_checkpointing = value

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert (
            decoder_start_token_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # shift inputs to the right
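        # e.g. labels [[x1, x2, x3]] become decoder_input_ids [[decoder_start_token_id, x1, x2]];
        # any -100 values carried over from the labels are then replaced with pad_token_id below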
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids


class T5Stack(T5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)

        # Set embed_tokens to first layer
        self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        self.embed_tokens = self.embed_tokens.to("cpu")
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

        if use_cache is True:
            assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return tuple(module(*inputs, use_cache, output_attentions))

                    return custom_forward

                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value is always None with gradient checkpointing
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            if use_cache is False:
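                # The layer returns no present key/value state when caching is disabled, so insert a
                # `None` placeholder to keep the tuple indices used below identical in both cases.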
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them.
            # layer_outputs = hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[2]
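            # When attention weights are returned they sit at index 3, which pushes the
            # cross-attention position bias from index 3 to index 4.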
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )


T5_START_DOCSTRING = r"""

    The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
    Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
    Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
    text-to-text denoising generative setting.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`T5Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.

        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

T5_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""


@add_start_docstrings(
    "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

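        # The encoder and decoder share `self.shared` as their token embedding; each stack gets its
        # own config copy so the flags set below only affect that stack.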
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
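        # `device_map` maps each GPU id to the block indices it should host, e.g. (illustrative
        # sketch) `model.parallelize({0: [0, 1, 2], 1: [3, 4, 5]})`; with `device_map=None` the
        # blocks are spread evenly across all visible CUDA devices.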
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import T5Tokenizer, T5Model

        >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
        >>> model = T5Model.from_pretrained("t5-small")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

        >>> # forward pass
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
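            # A plain tuple was passed for `encoder_outputs`; wrap it so the attribute access used
            # when building the return value below still works.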
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]
        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`.

        Returns:

        Examples:

        ```python
        >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

        >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
        >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
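            # `_shift_right` prepends `decoder_start_token_id` (the pad token for T5), drops the last
            # position, and replaces any -100 in the labels with the pad token id.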
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
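            # Flatten logits to (batch * target_len, vocab_size) and labels to (batch * target_len,);
            # positions labeled -100 are excluded from the loss via `ignore_index`.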
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
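        # Called by `generate()` at each decoding step to assemble the decoder-side inputs from the
        # current sequences and the cached key/value states.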

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
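        # Beam search reorders the hypotheses at every step, so each cached key/value tensor has to
        # be gathered along its batch dimension with `beam_idx`.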
        for layer_past_states in past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past


@add_start_docstrings(
    "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5EncoderModel(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import T5Tokenizer, T5EncoderModel

        >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
        >>> model = T5EncoderModel.from_pretrained("t5-small")
        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs