# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model."""


import copy
import math
import os
import warnings

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint

from ...activations import ACT2FN
from ...file_utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_t5 import T5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
_CHECKPOINT_FOR_DOC = "t5-small"

####################################################
# This list contains the ids of the pretrained
# checkpoints provided with the models
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]


####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    tf_weights = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        tf_weights[name] = array

    for txt_name in names:
        name = txt_name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        if "_slot_" in name[-1]:
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        pointer = model
        array = tf_weights[txt_name]

        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ["kernel", "scale", "embedding"]:
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "self_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[0]
            elif scope_names[0] == "enc_dec_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[1]
            elif scope_names[0] == "dense_relu_dense":
                pointer = getattr(pointer, "layer")
                pointer = pointer[2]
            elif scope_names[0] == "rms_norm":
                if hasattr(pointer, "layer_norm"):
                    pointer = getattr(pointer, "layer_norm")
                elif hasattr(pointer, "final_layer_norm"):
                    pointer = getattr(pointer, "final_layer_norm")
            elif scope_names[0] == "scale":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            elif scope_names[0] == "decoder" and name[1] == "logits":
                continue
            elif scope_names[0] == "logits":
                pointer = getattr(pointer, "lm_head")
            elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
                pointer = getattr(pointer, f"wi_{scope_names[1]}")
                continue
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if scope_names[0] not in ["kernel", "scale", "embedding"]:
            pointer = getattr(pointer, "weight")
        if scope_names[0] != "embedding":
            logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array.astype(np.float32))
        tf_weights.pop(txt_name, None)

    logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
    return model


####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (itself a sub-class of nn.Module)
####################################################
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (`Dict[int, list]`, optional, defaults to None):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
            following number of attention modules:

                - t5-small: 6
                - t5-base: 12
                - t5-large: 24
                - t5-3b: 24
                - t5-11b: 24

    Example:

    ```python
    # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
    model = T5ForConditionalGeneration.from_pretrained("t5-3b")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)
    ```
"""
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example:

    ```python
    # On a 4 GPU machine with t5-3b:
    model = T5ForConditionalGeneration.from_pretrained("t5-3b")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # layer norm should always be calculated in float32
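        # T5's layer norm is an RMS norm: scale by 1/sqrt(mean(x**2) + eps) and a learned weight,
        # with no mean subtraction and no bias.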
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = nn.functional.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

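    # Gated-GELU feed-forward: wo(dropout(gelu(wi_0(x)) * wi_1(x))), selected via
    # config.feed_forward_proj == "gated-gelu" in T5LayerFF below.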
    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5LayerFF(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.feed_forward_proj == "relu":
            self.DenseReluDense = T5DenseReluDense(config)
        elif config.feed_forward_proj == "gated-gelu":
            self.DenseReluDense = T5DenseGatedGeluDense(config)
        else:
            raise ValueError(
                f"{config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
            )

        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
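        # Pre-norm residual block: layer norm -> feed-forward -> dropout -> add back to the input.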
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states


class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias

        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias"""
        context_position = torch.arange(
            query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
        )[:, None]
        memory_position = torch.arange(
            key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
        )[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/value states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
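        # Note: scores are not divided by sqrt(dim_per_head); the query projection is initialized so
        # that this scaling is unnecessary (see the Mesh TensorFlow note in __init__).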

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length)

            # if keys and values are already calculated,
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        scores += position_bias
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs


class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5LayerCrossAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
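        # Sub-layer order: self-attention, then cross-attention (decoder only), then feed-forward.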
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(T5LayerCrossAttention(config))

        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):

        if past_key_value is not None:
            assert self.is_decoder, "Only decoder can use `past_key_values`"
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
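        # (float16 max is ~65504, so activations are clamped to +/-(max - 1000) to keep them finite)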
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)


class T5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, T5DenseReluDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5DenseGatedGeluDense):
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5Attention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (T5Attention, T5Stack)):
            module.gradient_checkpointing = value

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert (
            decoder_start_token_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # shift inputs to the right
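        # e.g. labels [w1, w2, w3, eos] become decoder inputs [decoder_start_token_id, w1, w2, w3]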
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids


class T5Stack(T5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
        )
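        # Only the first block gets a relative attention bias; the other blocks reuse the
        # position_bias it computes (shared across layers in the forward pass).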
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)

        # Set embed_tokens to first layer
        self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        self.embed_tokens = self.embed_tokens.to("cpu")
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

        if use_cache is True:
            assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return tuple(module(*inputs, use_cache, output_attentions))

                    return custom_forward

                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value is always None with gradient checkpointing
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them
            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )


T5_START_DOCSTRING = r"""

    The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
    Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
    Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
    text-to-text denoising generative setting.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`T5Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""

T5_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""

# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""


@add_start_docstrings(
    "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.model_parallel = True
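
    # A hedged usage sketch (assuming 4 GPUs and a checkpoint with 24 encoder blocks, e.g. "t5-3b");
    # `device_map` maps a CUDA device id to the block indices placed on that device:
    #     model = T5Model.from_pretrained("t5-3b")
    #     model.parallelize({0: list(range(0, 6)), 1: list(range(6, 12)),
    #                        2: list(range(12, 18)), 3: list(range(18, 24))})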

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import T5Tokenizer, T5Model

        >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
        >>> model = T5Model.from_pretrained("t5-small")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

        >>> # forward pass
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
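                # reusing the encoder mask is only well-defined when both stacks have the same number of layers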
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
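        # a user-provided `encoder_outputs` tuple is wrapped into a `BaseModelOutput` below so that
        # attribute access (e.g. `encoder_outputs.last_hidden_state`) works when `return_dict=True`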
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


@add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`.

        Returns:

        Examples:

        ```python
        >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

        >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
        >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
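            # (with tied embeddings, `lm_head` shares the input embedding matrix, so the decoder output
            # is rescaled here to reproduce the logit scale of the original Mesh TF implementation)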
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):

        # cut decoder_input_ids if past is used
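        # (earlier positions are already encoded in the cached key/value states, so only the newest token is needed)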
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }
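
    # A hedged sketch of how the cache is used during incremental decoding (illustrative only;
    # `generate()` does this internally and also reuses `encoder_outputs` instead of re-encoding):
    #     out = model(input_ids=input_ids, decoder_input_ids=decoder_start, use_cache=True)
    #     next_token = out.logits[:, -1:].argmax(-1)
    #     out = model(input_ids=input_ids, decoder_input_ids=next_token,
    #                 past_key_values=out.past_key_values, use_cache=True)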

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past, beam_idx):
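        # Called during beam search: reorder every layer's cached key/value tensors along the
        # batch dimension so that they track the newly selected beams (`beam_idx`).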
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past


@add_start_docstrings(
    "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5EncoderModel(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import T5Tokenizer, T5EncoderModel

        >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
        >>> model = T5EncoderModel.from_pretrained("t5-small")
        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs
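
        # A hedged usage sketch (outside this module): masked mean pooling of the encoder outputs
        # into sentence embeddings; the variable names are illustrative, not part of the API.
        #     hidden = outputs.last_hidden_state                      # (batch, seq_len, d_model)
        #     mask = attention_mask.unsqueeze(-1).type_as(hidden)     # (batch, seq_len, 1)
        #     embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)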