# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model. """


import copy
import math
import os
import warnings

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss

from ...file_utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_t5 import T5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"

####################################################
# This list contains the shortcut names
# for the pretrained weights provided with the library
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]


####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    tf_weights = {}
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        tf_weights[name] = array

    for txt_name in names:
        name = txt_name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
        # which are not required for using pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info("Skipping {}".format("/".join(name)))
            tf_weights.pop(txt_name, None)
            continue
        if "_slot_" in name[-1]:
            logger.info("Skipping {}".format("/".join(name)))
            tf_weights.pop(txt_name, None)
            continue
        pointer = model
        array = tf_weights[txt_name]

        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ["kernel", "scale", "embedding"]:
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "self_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[0]
            elif scope_names[0] == "enc_dec_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[1]
            elif scope_names[0] == "dense_relu_dense":
                pointer = getattr(pointer, "layer")
                pointer = pointer[2]
            elif scope_names[0] == "rms_norm":
                if hasattr(pointer, "layer_norm"):
                    pointer = getattr(pointer, "layer_norm")
                elif hasattr(pointer, "final_layer_norm"):
                    pointer = getattr(pointer, "final_layer_norm")
            elif scope_names[0] == "scale":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            elif scope_names[0] == "decoder" and name[1] == "logits":
                continue
            elif scope_names[0] == "logits":
                pointer = getattr(pointer, "lm_head")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if scope_names[0] not in ["kernel", "scale", "embedding"]:
            pointer = getattr(pointer, "weight")
        if scope_names[0] != "embedding":
            logger.info("Transposing numpy weight of shape {} for {}".format(array.shape, name))
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array.astype(np.float32))
        tf_weights.pop(txt_name, None)

    logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys())))
    return model
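# A minimal usage sketch for the conversion helper above; the checkpoint path is
# hypothetical, and T5ForConditionalGeneration is defined further down in this file.
#
#   config = T5Config.from_pretrained("t5-small")
#   model = T5ForConditionalGeneration(config)
#   model = load_tf_weights_in_t5(model, config, "/path/to/tf_checkpoint")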


####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (itself a sub-class of torch.nn.Module)
####################################################


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # layer norm should always be calculated in float32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into float16 if necessary
        if self.weight.dtype == torch.float16:
            hidden_states = hidden_states.to(torch.float16)
        return self.weight * hidden_states
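# Usage sketch (shapes are illustrative): T5LayerNorm rescales by the root mean
# square of the last dimension instead of subtracting the mean, then applies a
# learned per-feature scale:
#
#   layer_norm = T5LayerNorm(hidden_size=512, eps=1e-6)
#   hidden = torch.randn(2, 7, 512)       # (batch_size, seq_length, d_model)
#   normed = layer_norm(hidden)           # same shape, scaled by 1 / sqrt(mean(hidden ** 2) + eps)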


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = F.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states
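# Usage sketch (illustrative config values): the feed-forward block projects
# d_model -> d_ff, applies ReLU and dropout, then projects back d_ff -> d_model.
#
#   ff = T5DenseReluDense(T5Config(d_model=512, d_ff=2048))
#   out = ff(torch.randn(2, 7, 512))   # -> shape (2, 7, 512)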


class T5LayerFF(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.DenseReluDense = T5DenseReluDense(config)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states
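# Usage sketch: T5LayerFF is the pre-norm residual wrapper around T5DenseReluDense,
# i.e. hidden_states + dropout(DenseReluDense(layer_norm(hidden_states))).
#
#   ff_layer = T5LayerFF(T5Config(d_model=512, d_ff=2048))  # illustrative config
#   out = ff_layer(torch.randn(2, 7, 512))                  # -> shape (2, 7, 512)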


class T5Attention(nn.Module):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias

        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
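    # Worked example (values computed with the defaults above, bidirectional=True,
    # num_buckets=32, max_distance=128): the 32 buckets split into two halves of 16
    # for negative and positive relative positions; within each half the first 8
    # buckets are exact offsets and the rest grow logarithmically up to max_distance.
    #
    #   rel = torch.tensor([-3, 0, 1, 7, 8, 50, 200])
    #   T5Attention._relative_position_bucket(rel)  # -> tensor([ 3,  0, 17, 23, 24, 29, 31])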

    def compute_bias(self, query_length, key_length):
        """ Compute binned relative position bias """
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
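    # Usage sketch: compute_bias builds the additive attention bias from the learned
    # bucket embeddings, e.g. for a layer with has_relative_attention_bias=True:
    #
    #   bias = self.compute_bias(query_length=10, key_length=10)
    #   # bias shape: (1, n_heads, 10, 10); it is added to the raw attention scores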

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
                len(past_key_value)
            )
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """  projection """
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """  reshape """
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """ projects hidden states correctly to key/query states """
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # compute scores
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
            else:
                position_bias = self.compute_bias(real_seq_length, key_length)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        scores += position_bias
        attn_weights = F.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = F.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
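    # Usage sketch (illustrative config values): self-attention over a batch of
    # hidden states; with use_cache left at False the cached key/value slot is None.
    #
    #   config = T5Config(d_model=512, d_kv=64, num_heads=8)
    #   attn = T5Attention(config, has_relative_attention_bias=True)
    #   attn_output, present_key_value, position_bias = attn(torch.randn(2, 10, 512))[:3]
    #   # attn_output: (2, 10, 512); position_bias: (1, 8, 10, 10)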


class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5LayerCrossAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(T5LayerCrossAttention(config))

        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):

        if past_key_value is not None:
            assert self.is_decoder, "Only decoder can use `past_key_values`"
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
                expected_num_past_key_values,
                "2 (past / key) for cross attention" if expected_num_past_key_values == 4 else "",
                len(past_key_value),
            )
            assert len(past_key_value) == expected_num_past_key_values, error_message

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                head_mask=head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]
            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)
        outputs = (hidden_states,)

        outputs = outputs + (present_key_value_state,) + attention_outputs
        return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
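    # Usage sketch: a decoder block applies its sub-layers in order
    # [T5LayerSelfAttention, T5LayerCrossAttention, T5LayerFF]; an encoder block
    # skips the cross-attention. Illustrative encoder-style call, no cache:
    #
    #   block = T5Block(T5Config(d_model=512, d_kv=64, num_heads=8), has_relative_attention_bias=True)
    #   hidden_states = block(torch.randn(2, 10, 512))[0]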


class T5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """ Initialize the weights """
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, (T5Model, T5ForConditionalGeneration)):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, T5DenseReluDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5Attention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert (
            decoder_start_token_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids
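    # Worked example: with the usual T5 settings pad_token_id = 0 and
    # decoder_start_token_id = 0, labels are shifted one position to the right and
    # any -100 that ends up in the decoder inputs is replaced by the pad token:
    #
    #   labels = torch.tensor([[6, 7, -100, -100]])
    #   self._shift_right(labels)  # -> tensor([[0, 6, 7, 0]])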


class T5Stack(T5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embed_tokens

    def get_output_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

        if use_cache is True:
            assert self.is_decoder, ":obj:`use_cache` can only be set to `True` if {} is used as a decoder".format(
                self
            )

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)

        if self.is_decoder and encoder_attention_mask is not None:
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                head_mask=head_mask[i],
                past_key_value=past_key_value,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them
            # layer_outputs = hidden-states, key-value-states (self-attention weights),
            # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
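# Usage sketch (illustrative config values): a standalone encoder stack. use_cache
# must be disabled for a non-decoder stack, mirroring what T5Model does below.
#
#   config = T5Config(d_model=512, d_kv=64, num_heads=8, num_layers=6)
#   embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
#   encoder = T5Stack(config, embed_tokens)
#   out = encoder(input_ids=torch.tensor([[13, 7, 42]]), use_cache=False)
#   out.last_hidden_state.shape  # torch.Size([1, 3, 512])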


T5_START_DOCSTRING = r"""

    The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
    <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
    Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text
    denoising generative setting.

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
    general usage and behavior.

    Parameters:
        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            To know more on how to prepare :obj:`input_ids` for pretraining take a look at `T5 Training
            <./t5.html#training>`__.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            Provide for sequence to sequence training. T5 uses the :obj:`pad_token_id` as the starting token for
            :obj:`decoder_input_ids` generation. If :obj:`past_key_values` is used, optionally only the last
            :obj:`decoder_input_ids` have to be input (see :obj:`past_key_values`).

            To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
            <./t5.html#training>`__. If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset,
            :obj:`decoder_input_ids` takes the value of :obj:`input_ids`.
        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
            Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
            also be used by default.
        encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`):
            Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
            `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
            sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
            representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
            have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
            :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
            takes the value of :obj:`inputs_embeds`.

        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).

        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
    authorized_missing_keys = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.init_weights()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

980
981
982
983
984
985
    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
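        # e.g. ``model._prune_heads({0: [1, 2]})`` would prune heads 1 and 2 of the first encoder block's
        # self-attention (illustrative values, added comment).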

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        head_mask=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        Returns:

        Example::

            >>> from transformers import T5Tokenizer, T5Model

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5Model.from_pretrained('t5-small')

            >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
            >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

            >>> last_hidden_states = outputs.last_hidden_state
        """
        if "decoder_past_key_value_states" in kwargs:
            warnings.warn(
                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_value_states")
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
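
    # Usage sketch (added for illustration, not part of the original docstrings): the ``encoder_outputs``
    # argument of ``forward`` lets a caller run the encoder once and reuse it across decoder passes, e.g.
    #
    #   >>> encoder_outputs = model.get_encoder()(input_ids=input_ids, return_dict=True)
    #   >>> first = model(encoder_outputs=encoder_outputs, decoder_input_ids=decoder_input_ids)
    #   >>> second = model(encoder_outputs=encoder_outputs, decoder_input_ids=other_decoder_input_ids)
    #
    # where ``other_decoder_input_ids`` stands for a hypothetical second decoder prompt.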


@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
    authorized_missing_keys = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.init_weights()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self):
        return self.lm_head
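
    # Note (added comment): because ``get_output_embeddings`` returns ``lm_head``, the generic
    # ``PreTrainedModel`` weight-tying machinery can tie it to ``self.shared`` when word-embedding tying is
    # enabled in the config.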

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        head_mask=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in :obj:`[-100, 0,
            ..., config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed
            for labels in ``[0, ..., config.vocab_size - 1]``.
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

        Returns:

        Examples::

            >>> from transformers import T5Tokenizer, T5ForConditionalGeneration

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5ForConditionalGeneration.from_pretrained('t5-small')

            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
            >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits

            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model.generate(input_ids)
        """

        if "lm_labels" in kwargs:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("lm_labels")
        if "decoder_past_key_value_states" in kwargs:
            warnings.warn(
                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_value_states")
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs into embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)
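            # (added comment) ``_shift_right`` prepends ``config.decoder_start_token_id`` and drops the final
            # label token, so at each position the decoder learns to predict ``labels[t]`` from ``labels[:t]``.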

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)
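        # (added comment) e.g. for t5-small (d_model=512) this multiplies the decoder hidden states by
        # 512 ** -0.5 ~= 0.044 before the tied lm_head projection.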
        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
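            # (added comment) logits are flattened to (batch_size * seq_len, vocab_size) and labels to
            # (batch_size * seq_len,); positions labelled -100 (e.g. padding) do not contribute to the loss.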
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }
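
    # Call-sequence sketch (illustrative, not taken verbatim from the generation utilities): during generation
    # this method is invoked at every decoding step, roughly as
    #
    #   >>> model_inputs = model.prepare_inputs_for_generation(
    #   ...     input_ids, past=past, encoder_outputs=enc_out, attention_mask=attention_mask, use_cache=True
    #   ... )
    #   >>> step_out = model(**model_inputs, return_dict=True)
    #   >>> past = step_out.past_key_values
    #
    # where ``enc_out`` stands for a previously computed encoder output.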

    def _reorder_cache(self, past, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # get the correct batch idx from the layer past's batch dimension;
            # the cached key / value states keep the batch dimension first, so we gather along dim 0
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past
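
        # (added note) Illustrative shapes: with a batch of 2 and beam width 4, ``beam_idx`` has 8 entries and
        # each cached key / value tensor of shape (8, n_heads, seq_len, head_dim) is gathered along dim 0.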