# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model. """


import copy
import math
import os
import warnings

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint

from ...activations import ACT2FN
from ...file_utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_t5 import T5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
_CHECKPOINT_FOR_DOC = "t5-small"

####################################################
# This dict contains ids and associated URLs
# for the pretrained weights provided with the models
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]


####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    tf_weights = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        tf_weights[name] = array

    for txt_name in names:
        name = txt_name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        if "_slot_" in name[-1]:
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        pointer = model
        array = tf_weights[txt_name]

        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ["kernel", "scale", "embedding"]:
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "self_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[0]
            elif scope_names[0] == "enc_dec_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[1]
            elif scope_names[0] == "dense_relu_dense":
                pointer = getattr(pointer, "layer")
                pointer = pointer[2]
            elif scope_names[0] == "rms_norm":
                if hasattr(pointer, "layer_norm"):
                    pointer = getattr(pointer, "layer_norm")
                elif hasattr(pointer, "final_layer_norm"):
                    pointer = getattr(pointer, "final_layer_norm")
            elif scope_names[0] == "scale":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            elif scope_names[0] == "decoder" and name[1] == "logits":
                continue
            elif scope_names[0] == "logits":
                pointer = getattr(pointer, "lm_head")
            elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
                pointer = getattr(pointer, f"wi_{scope_names[1]}")
                continue
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if scope_names[0] not in ["kernel", "scale", "embedding"]:
            pointer = getattr(pointer, "weight")
        if scope_names[0] != "embedding":
            logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array.astype(np.float32))
        tf_weights.pop(txt_name, None)

    logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
    return model
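# A minimal usage sketch for the converter above (illustrative only; the paths below are
# hypothetical): build a randomly initialized model from a T5Config, then overwrite its weights
# with the values stored in a TF 1.x checkpoint.
#
#     config = T5Config.from_json_file("/path/to/config.json")
#     model = T5ForConditionalGeneration(config)
#     model = load_tf_weights_in_t5(model, config, "/path/to/tf_checkpoint")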


####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (itself a sub-class of nn.Module)
####################################################
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (:obj:`Dict[int, list]`, optional, defaults to None):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
            following number of attention modules:

                - t5-small: 6
                - t5-base: 12
                - t5-large: 24
                - t5-3b: 24
                - t5-11b: 24

    Example::

            # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
            model = T5ForConditionalGeneration.from_pretrained('t5-3b')
            device_map = {0: [0, 1, 2],
                         1: [3, 4, 5, 6, 7, 8, 9],
                         2: [10, 11, 12, 13, 14, 15, 16],
                         3: [17, 18, 19, 20, 21, 22, 23]}
            model.parallelize(device_map)
"""
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to CPU from a model parallel state.

    Example::

        # On a 4 GPU machine with t5-3b:
        model = T5ForConditionalGeneration.from_pretrained('t5-3b')
        device_map = {0: [0, 1, 2],
                     1: [3, 4, 5, 6, 7, 8, 9],
                     2: [10, 11, 12, 13, 14, 15, 16],
                     3: [17, 18, 19, 20, 21, 22, 23]}
        model.parallelize(device_map)  # Splits the model across several devices
        model.deparallelize()  # Puts the model back on CPU and frees memory by calling torch.cuda.empty_cache()
"""


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # layer norm should always be calculated in float32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
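# Note: T5LayerNorm above is an RMSNorm. It only rescales by the root mean square of the last
# dimension and applies a learned per-channel weight; there is no mean subtraction and no bias.
# A minimal reference sketch of the same computation (illustrative only, not used by the model):
#
#     def rms_norm_reference(x, weight, eps=1e-6):
#         variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
#         return weight * (x * torch.rsqrt(variance + eps))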


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = nn.functional.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states
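# Reference sketch of the gated feed-forward above (illustrative only, not used by the model):
# one projection is passed through the "gelu_new" activation and acts as a gate on a second,
# purely linear projection, i.e.
#
#     hidden = gelu_new(wi_0(x)) * wi_1(x)
#     output = wo(dropout(hidden))
#
# This is the variant selected when `config.feed_forward_proj == "gated-gelu"` (see T5LayerFF below).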


class T5LayerFF(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.feed_forward_proj == "relu":
            self.DenseReluDense = T5DenseReluDense(config)
        elif config.feed_forward_proj == "gated-gelu":
            self.DenseReluDense = T5DenseGatedGeluDense(config)
        else:
            raise ValueError(
                f"{config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
            )

        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states


class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias

        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias"""
        context_position = torch.arange(
            query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
        )[:, None]
        memory_position = torch.arange(
            key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
        )[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        scores += position_bias
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
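# Note on scaling: unlike the standard softmax(Q K^T / sqrt(d_k)) attention, T5Attention above
# computes softmax(Q K^T + position_bias) without the 1 / sqrt(d_k) factor. As the
# "Mesh TensorFlow initialization" comment in its __init__ indicates, the scaling is instead
# folded into the initialization of the query projection (see T5PreTrainedModel._init_weights).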


class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5LayerCrossAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(T5LayerCrossAttention(config))

        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):

        if past_key_value is not None:
            assert self.is_decoder, "Only decoder can use `past_key_values`"
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)


class T5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, T5DenseReluDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5DenseGatedGeluDense):
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5Attention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (T5Attention, T5Stack)):
            module.gradient_checkpointing = value
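    # Gradient checkpointing is toggled through the generic PreTrainedModel helpers; a usage
    # sketch (assuming the standard `gradient_checkpointing_enable` API is available):
    #
    #     model = T5ForConditionalGeneration.from_pretrained("t5-small")
    #     model.gradient_checkpointing_enable()  # ends up setting `gradient_checkpointing = True`
    #                                            # on T5Stack and T5Attention via _set_gradient_checkpointing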

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert (
            decoder_start_token_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids
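    # Worked example of the shift above (illustrative values; in the released T5 checkpoints both
    # pad_token_id and decoder_start_token_id are 0): labels [[x1, x2, x3, -100]] become
    # decoder_input_ids [[0, x1, x2, x3]], i.e. the targets shifted one position to the right and
    # prefixed with the decoder start token, with any remaining -100 replaced by the pad id.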


class T5Stack(T5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)

        # Set embed_tokens to first layer
        self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        self.embed_tokens = self.embed_tokens.to("cpu")
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

        if use_cache is True:
            assert self.is_decoder, f":obj:`use_cache` can only be set to `True` if {self} is used as a decoder"

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return tuple(module(*inputs, use_cache, output_attentions))

                    return custom_forward

                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value is always None with gradient checkpointing
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

1024
            hidden_states, present_key_value_state = layer_outputs[:2]
1025

1026
            # We share the position biases between the layers - the first layer stores them
            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )


T5_START_DOCSTRING = r"""

    The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
    <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
    Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a text-to-text
    denoising generative setting.

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
    general usage and behavior.

    Parameters:
        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__

            To know more on how to prepare :obj:`input_ids` for pretraining take a look at `T5 Training
            <./t5.html#training>`__.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__

            T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
            :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
            :obj:`past_key_values`).

            To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
            <./t5.html#training>`__.
        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
            also be used by default.
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
            1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
            1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
            Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
            `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
            sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
            representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
            have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
            :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
            takes the value of :obj:`inputs_embeds`.

        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).

        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""

T5_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            To know more on how to prepare :obj:`input_ids` for pretraining take a look at `T5 Training
            <./t5.html#training>`__.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""

# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""


@add_start_docstrings(
    "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

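        # The encoder and decoder stacks share the token embedding above; each stack gets its own
        # config copy with the decoder / encoder-decoder flags adjusted (and the decoder its own depth).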
        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

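    # The device_map maps each GPU index to the list of block indices it should host, e.g. a
    # hypothetical split of a 6-block stack over 2 GPUs would be {0: [0, 1, 2], 1: [3, 4, 5]};
    # when no map is given, get_device_map spreads the blocks evenly over the visible CUDA devices.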
    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:

        Example::

            >>> from transformers import T5Tokenizer, T5Model

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5Model.from_pretrained('t5-small')

            >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
            >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

            >>> # forward pass
            >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
            >>> last_hidden_states = outputs.last_hidden_state
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]
        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
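        # When config.tie_word_embeddings is True (the default), the base class ties lm_head.weight to
        # self.shared; forward() then rescales the decoder output by d_model**-0.5 before this projection.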

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed for
            labels in ``[0, ..., config.vocab_size - 1]``.

        Returns:

        Examples::

            >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5ForConditionalGeneration.from_pretrained('t5-small')

            >>> # training
            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
            >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits

            >>> # inference
            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model.generate(input_ids)
            >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
            >>> # studies have shown that owning a dog is good for you.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
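        # Called by generate() before each decoding step to assemble the decoder inputs; once past
        # key/value states exist, only the most recently generated token id needs to be fed in.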

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # get the correct batch idx from layer past batch dim
            # (each cached key/value tensor has shape (batch_size, n_heads, seq_length, head_dim), so batch is dim 0)
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past


@add_start_docstrings(
    "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5EncoderModel(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:

        Example::

            >>> from transformers import T5Tokenizer, T5EncoderModel
            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5EncoderModel.from_pretrained('t5-small')
            >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model(input_ids=input_ids)
            >>> last_hidden_states = outputs.last_hidden_state
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs