# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model. """


import copy
import math
import os
import warnings

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint

from ...activations import ACT2FN
from ...file_utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_t5 import T5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"

####################################################
# This list contains the model ids
# for the pretrained weights provided with the models
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]
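# The identifiers above can be passed straight to `from_pretrained`. A minimal usage sketch,
# kept as a comment so nothing runs at import time (`T5Tokenizer` lives in the tokenization
# module and `T5ForConditionalGeneration` is defined further down in this file):
#
#     from transformers import T5ForConditionalGeneration, T5Tokenizer
#
#     tokenizer = T5Tokenizer.from_pretrained("t5-small")
#     model = T5ForConditionalGeneration.from_pretrained("t5-small")
#     input_ids = tokenizer("translate English to German: Hello!", return_tensors="pt").input_ids
#     print(tokenizer.decode(model.generate(input_ids)[0], skip_special_tokens=True))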

####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    tf_weights = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        tf_weights[name] = array

    for txt_name in names:
        name = txt_name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        if "_slot_" in name[-1]:
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        pointer = model
        array = tf_weights[txt_name]

        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ["kernel", "scale", "embedding"]:
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "self_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[0]
            elif scope_names[0] == "enc_dec_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[1]
            elif scope_names[0] == "dense_relu_dense":
                pointer = getattr(pointer, "layer")
                pointer = pointer[2]
            elif scope_names[0] == "rms_norm":
                if hasattr(pointer, "layer_norm"):
                    pointer = getattr(pointer, "layer_norm")
                elif hasattr(pointer, "final_layer_norm"):
                    pointer = getattr(pointer, "final_layer_norm")
            elif scope_names[0] == "scale":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            elif scope_names[0] == "decoder" and name[1] == "logits":
                continue
            elif scope_names[0] == "logits":
                pointer = getattr(pointer, "lm_head")
            elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
                pointer = getattr(pointer, f"wi_{scope_names[1]}")
                continue
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if scope_names[0] not in ["kernel", "scale", "embedding"]:
            pointer = getattr(pointer, "weight")
        if scope_names[0] != "embedding":
            logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array.astype(np.float32))
        tf_weights.pop(txt_name, None)

    logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
    return model
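# A minimal sketch of how this conversion helper is typically driven (illustrative only; the
# paths are placeholders, and in practice the TF->PyTorch conversion script shipped with the
# library wraps essentially these steps):
#
#     config = T5Config.from_json_file("/path/to/t5_config.json")
#     model = T5ForConditionalGeneration(config)
#     load_tf_weights_in_t5(model, config, "/path/to/tf_checkpoint")
#     model.save_pretrained("/path/to/pytorch_dump")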


####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (it-self a sub-class of torch.nn.Module)
####################################################
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (:obj:`Dict[int, list]`, optional, defaults to None):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
            following number of attention modules:

                - t5-small: 6
                - t5-base: 12
                - t5-large: 24
                - t5-3b: 24
                - t5-11b: 24

    Example::

            # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
            model = T5ForConditionalGeneration.from_pretrained('t5-3b')
            device_map = {0: [0, 1, 2],

                         1: [3, 4, 5, 6, 7, 8, 9],
                         2: [10, 11, 12, 13, 14, 15, 16],
                         3: [17, 18, 19, 20, 21, 22, 23]}
            model.parallelize(device_map)
"""
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example::

        # On a 4 GPU machine with t5-3b:
        model = T5ForConditionalGeneration.from_pretrained('t5-3b')
        device_map = {0: [0, 1, 2],

                     1: [3, 4, 5, 6, 7, 8, 9],
                     2: [10, 11, 12, 13, 14, 15, 16],
                     3: [17, 18, 19, 20, 21, 22, 23]}
        model.parallelize(device_map) # Splits the model across several devices
        model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
"""


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # layer norm should always be calculated in float32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into float16 if necessary
        if self.weight.dtype == torch.float16:
            hidden_states = hidden_states.to(torch.float16)
        return self.weight * hidden_states
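# T5LayerNorm is a root-mean-square (RMS) norm: y = weight * x / sqrt(mean(x**2) + eps), with no
# mean subtraction and no bias. A tiny equivalence sketch (illustrative only, not executed here):
#
#     x = torch.randn(2, 5, 512)
#     ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
#     assert torch.allclose(T5LayerNorm(512)(x), ref)  # the weight is initialized to ones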


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = F.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedGeluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states
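# The gated variant above is a GEGLU-style feed forward: FF(x) = wo(gelu_new(wi_0(x)) * wi_1(x)).
# `T5LayerFF` below selects it when `config.feed_forward_proj == "gated-gelu"` (used, for example,
# by the google/t5-v1_1-* checkpoints) instead of the plain ReLU block above.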


class T5LayerFF(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.feed_forward_proj == "relu":
            self.DenseReluDense = T5DenseReluDense(config)
        elif config.feed_forward_proj == "gated-gelu":
            self.DenseReluDense = T5DenseGatedGeluDense(config)
        else:
            raise ValueError(
                f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
            )

        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states


class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias

        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
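    # Illustrative sketch of the bucketing above (not executed): with the defaults num_buckets=32,
    # max_distance=128 and bidirectional=False (the decoder case, where num_buckets is not halved),
    # a relative position of -1 maps to bucket 1 and -10 to bucket 10 (exact buckets cover distances
    # up to max_exact=16), -64 lands in logarithmic bucket 26, anything at or beyond -128 saturates
    # in bucket 31, and all non-negative (future) positions share bucket 0.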

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        int_seq_length = int(seq_length)

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.training and self.gradient_checkpointing:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -int_seq_length:, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        scores += position_bias
        attn_weights = F.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = F.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs


class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5LayerCrossAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(T5LayerCrossAttention(config))

        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):

        if past_key_value is not None:
            assert self.is_decoder, "Only decoder can use `past_key_values`"
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention' if expected_num_past_key_values == 4 else ''}."
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)


class T5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
    is_parallelizable = True

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, T5DenseReluDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5DenseGatedGeluDense):
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5Attention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert (
            decoder_start_token_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids
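    # Illustrative sketch (not executed): with decoder_start_token_id=0 and pad_token_id=0,
    # labels [[8774, 10, 1]] (a sequence ending with the EOS id 1) are shifted to the decoder
    # inputs [[0, 8774, 10]], and any -100 entries that end up in the shifted sequence are
    # replaced by the pad token so the decoder is always fed valid ids during teacher forcing.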


class T5Stack(T5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        self.init_weights()
        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)

        # Set embed_tokens to first layer
        self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        self.embed_tokens = self.embed_tokens.to("cpu")
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

        if use_cache is True:
            assert self.is_decoder, f":obj:`use_cache` can only be set to `True` if {self} is used as a decoder"

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                        "`use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return tuple(module(*inputs, use_cache, output_attentions))

                    return custom_forward

                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value is always None with gradient checkpointing
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them
            # layer_outputs = hidden-states, key-value-states (self-attention weights),
            # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
1023
            # append next layer key value states
1024
1025
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
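                    # v[-1] is the index of the last block placed on device k; once that block has run,
                    # the hidden states are moved ahead to the next device in the device map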
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )


T5_START_DOCSTRING = r"""

    The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
    <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
    Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a text-to-text
    denoising generative setting.

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
    general usage and behavior.

    Parameters:
        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__

            To learn more about how to prepare :obj:`input_ids` for pretraining, take a look at `T5 Training
            <./t5.html#training>`__.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__

            T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
            :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
            :obj:`past_key_values`).

            To learn more about how to prepare :obj:`decoder_input_ids` for pretraining, take a look at `T5 Training
            <./t5.html#training>`__.
        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
            also be used by default.
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
            1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
            1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
            Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
            `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
            sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
            representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
            have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
            :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
            takes the value of :obj:`inputs_embeds`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""

T5_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            To learn more about how to prepare :obj:`input_ids` for pretraining, take a look at `T5 Training
            <./t5.html#training>`__.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""

# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""


@add_start_docstrings(
    "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.model_parallel = True
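
    # A minimal usage sketch (illustrative device assignment, not a recommendation): t5-small has 6
    # encoder blocks, so a two-GPU split can be given explicitly, or `device_map=None` lets
    # `get_device_map` balance the blocks automatically.
    #
    #     model = T5Model.from_pretrained('t5-small')
    #     model.parallelize({0: [0, 1, 2], 1: [3, 4, 5]})
    #     ...
    #     model.deparallelize()  # moves both stacks back to CPU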

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:

        Example::

            >>> from transformers import T5Tokenizer, T5Model

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5Model.from_pretrained('t5-small')

            >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
            >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

            >>> last_hidden_states = outputs.last_hidden_state
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]
        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in :obj:`[-100, 0,
            ..., config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked); the loss is only computed
            for labels in ``[0, ..., config.vocab_size - 1]``.

        Returns:

        Examples::

            >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5ForConditionalGeneration.from_pretrained('t5-small')

            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
            >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits

            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model.generate(input_ids)
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]
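            # the cached key/value states already cover every earlier decoder position, so only the
            # most recently generated token has to be embedded and attended to on this step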

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past


@add_start_docstrings(
    "The bare T5 Model transformer outputting encoder's raw hidden-states" "without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5EncoderModel(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Returns:

        Example::

            >>> from transformers import T5Tokenizer, T5EncoderModel
            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5EncoderModel.from_pretrained('t5-small')
            >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model(input_ids=input_ids)
            >>> last_hidden_states = outputs.last_hidden_state
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs