# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model. """

from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
import os
import sys
import copy
import itertools
from io import open

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from .modeling_utils import PreTrainedModel, prune_linear_layer
from .configuration_t5 import T5Config
from .file_utils import add_start_docstrings, DUMMY_INPUTS, DUMMY_MASK

logger = logging.getLogger(__name__)

####################################################
# This dict contains shortcut names and associated URLs
# for the pretrained weights provided with the models
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
    't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-pytorch_model.bin",
    't5-base': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-pytorch_model.bin",
    't5-large': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-pytorch_model.bin",
    't5-3b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-3b-pytorch_model.bin",
    't5-11b': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-11b-pytorch_model.bin",
}

####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    tf_weights = {}
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        tf_weights[name] = array

    for txt_name in names:
        name = txt_name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            tf_weights.pop(txt_name, None)
            continue
        if '_slot_' in name[-1]:
            logger.info("Skipping {}".format("/".join(name)))
            tf_weights.pop(txt_name, None)
            continue
        pointer = model
        array = tf_weights[txt_name]
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] in ['kernel', 'scale', 'embedding']:
                pointer = getattr(pointer, 'weight')
            # elif l[0] == 'scale':
            #     pointer = getattr(pointer, 'weight')
            # elif l[0] == 'output_bias' or l[0] == 'beta':
            #     pointer = getattr(pointer, 'bias')
            # elif l[0] == 'squad':
            #     pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, l[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        if l[0] not in ['kernel', 'scale', 'embedding']:
            pointer = getattr(pointer, 'weight')
        if l[0] != 'embedding':
            logger.info("Transposing numpy weight of shape {} for {}".format(array.shape, name))
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array.astype(np.float32))
        tf_weights.pop(txt_name, None)

    logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
    return model
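
# Illustrative usage sketch (not exercised anywhere in this file; the checkpoint path is a placeholder):
#
#   config = T5Config.from_pretrained('t5-small')
#   model = T5Model(config)
#   model = load_tf_weights_in_t5(model, config, '/path/to/mesh_tf_checkpoint')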


####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (itself a sub-class of torch.nn.Module)
####################################################

class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """ Construct a layernorm module in the T5 style
            No bias and no subtraction of mean.
        """
        super(T5LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
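        # RMS-style normalization: no mean subtraction and no bias, i.e.
        # x / sqrt(mean(x ** 2) + eps), rescaled by the learned `weight`.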
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * x


class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super(T5DenseReluDense, self).__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        h = self.wi(hidden_states)
        h = F.relu(h)
        h = self.dropout(h)
        h = self.wo(h)
        return h


class T5LayerFF(nn.Module):
    def __init__(self, config):
        super(T5LayerFF, self).__init__()
        self.DenseReluDense = T5DenseReluDense(config)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
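        # Pre-norm residual block: hidden_states + Dropout(DenseReluDense(LayerNorm(hidden_states))).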
        norm_x = self.layer_norm(hidden_states)
        y = self.DenseReluDense(norm_x)
        layer_output = hidden_states + self.dropout(y)
        return layer_output


class T5Attention(nn.Module):
    NEW_ID = itertools.count()

    def __init__(self, config, has_relative_attention_bias=False):
        super(T5Attention, self).__init__()
        self.layer_id = next(T5Attention.NEW_ID)
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias

        self.output_attentions = config.output_attentions
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.d_model = config.d_model
        self.d_kv = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.d_kv

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_heads, self.d_kv)
        heads = set(heads) - self.pruned_heads
        for head in heads:
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.d_kv * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position,
                                  bidirectional=True,
                                  num_buckets=32,
                                  max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position.  If bidirectional=False, then positive relative positions are
        invalid.
        We use smaller buckets for small absolute relative_position and larger buckets
        for larger absolute relative_positions.  All relative positions >=max_distance
        map to the same bucket.  All relative positions <=-max_distance map to the
        same bucket.  This should allow for more graceful generalization to longer
        sequences than the model has been trained on.
        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += (n < 0).to(torch.long) * num_buckets  # mtf.to_int32(mtf.less(n, 0)) * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))
        # now n is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = (n < max_exact)

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact)
            / math.log(max_distance / max_exact) * (num_buckets - max_exact)).to(torch.long)
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def compute_bias(self, qlen, klen):
        """ Compute binned relative position bias """
        context_position = torch.arange(qlen, dtype=torch.long)[:, None]
        memory_position = torch.arange(klen, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)
        rp_bucket = self._relative_position_bucket(relative_position,  # shape (qlen, klen)
                                                   bidirectional=not self.is_decoder,
                                                   num_buckets=self.relative_attention_num_buckets)
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
        return values

    def forward(self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).
        """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        bs, qlen, dim = input.size()
        if kv is None:
            klen = qlen if cache is None else cache['slen'] + qlen
        else:
            klen = kv.size(1)

        def shape(x):
            """  projection """
            return x.view(bs, -1, self.n_heads, self.d_kv).transpose(1, 2)

        def unshape(x):
            """  compute context """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim)

        q = shape(self.q(input))                                          # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k(input))                                      # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v(input))                                      # (bs, n_heads, qlen, dim_per_head)
        elif cache is None or self.layer_id not in cache:
            k = v = kv
            k = shape(self.k(k))                                          # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v(v))                                          # (bs, n_heads, qlen, dim_per_head)

        if cache is not None:
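            # The cache holds one (k, v) pair per layer: for self-attention (kv is None) the
            # new projections are appended along the time dimension, while for cross-attention
            # the cached projections of the static encoder output are reused as-is.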
            if self.layer_id in cache:
                if kv is None:
                    k_, v_ = cache[self.layer_id]
                    k = torch.cat([k_, k], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                    v = torch.cat([v_, v], dim=2)                             # (bs, n_heads, klen, dim_per_head)
                else:
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)

        # q = q / math.sqrt(dim_per_head)                                     # No scaling in T5
        scores = torch.einsum('bnqd,bnkd->bnqk', q, k)                        # (bs, n_heads, qlen, klen)

        if position_bias is None:
            if not self.has_relative_attention_bias:
                raise ValueError("No position_bias provided and no weights to compute position_bias")
            position_bias = self.compute_bias(qlen, klen)
            if mask is not None:
                position_bias = position_bias + mask                          # (bs, n_heads, qlen, klen)

        scores += position_bias
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)           # (bs, n_heads, qlen, klen)
        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)                                    # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)                                            # (bs, qlen, dim)

        context = self.o(context)

        outputs = (context,)
        if self.output_attentions:
            outputs = outputs + (weights,)
        if self.has_relative_attention_bias:
            outputs = outputs + (position_bias,)
        return outputs


class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super(T5LayerSelfAttention, self).__init__()
        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None):
        norm_x = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(norm_x,
                                              mask=attention_mask,
                                              position_bias=position_bias,
                                              head_mask=head_mask)
        y = attention_output[0]
        layer_output = hidden_states + self.dropout(y)
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5LayerCrossAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super(T5LayerCrossAttention, self).__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None):
        norm_x = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(norm_x,
                                                mask=attention_mask,
                                                kv=kv,
                                                position_bias=position_bias,
                                                head_mask=head_mask)
        y = attention_output[0]
        layer_output = hidden_states + self.dropout(y)
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super(T5Block, self).__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.is_decoder:
            self.layer.append(T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias))
            self.layer.append(T5LayerFF(config))
        else:
            self.layer.append(T5LayerFF(config))

    def forward(self, hidden_states, attention_mask=None, position_bias=None,
                encoder_hidden_states=None, encoder_attention_mask=None, encoder_decoder_position_bias=None,
                head_mask=None):
        self_attention_outputs = self.layer[0](hidden_states,
                                                attention_mask=attention_mask,
                                                position_bias=position_bias,
                                                head_mask=head_mask)
        hidden_states = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights

        if not self.is_decoder:
            hidden_states = self.layer[1](hidden_states)
        else:
            cross_attention_outputs = self.layer[1](hidden_states,
                                                    kv=encoder_hidden_states,
                                                    attention_mask=encoder_attention_mask,
                                                    position_bias=encoder_decoder_position_bias,
                                                    head_mask=head_mask)
            hidden_states = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # Keep cross-attention outputs and relative position weights
            hidden_states = self.layer[2](hidden_states)

        outputs = (hidden_states,) + outputs  # add attentions if we output them
        return outputs  # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)


class T5PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = T5Config
    pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {'decoder_input_ids': input_ids,
                        'encoder_input_ids': input_ids,
                        'decoder_attention_mask': input_mask}
        return dummy_inputs

    def _init_weights(self, module):
        """ Initialize the weights """
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor*1.0)
        elif isinstance(module, (T5Model, T5WithLMHeadModel)):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor*1.0)
        elif isinstance(module, T5DenseReluDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor*((self.config.d_model) ** -0.5))
            if hasattr(module.wi, 'bias') and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor*((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, 'bias') and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5Attention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            d_kv = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor*((d_model * d_kv) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor*(d_model ** -0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor*(d_model ** -0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor*((n_heads * d_kv) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor*((d_model) ** -0.5))


class T5Stack(T5PreTrainedModel):
    def __init__(self, config):
        super(T5Stack, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.is_decoder = config.is_decoder

        self.block = nn.ModuleList([T5Block(config, has_relative_attention_bias=bool(i == 0))
                                    for i in range(config.num_layers)])
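        # Only the first block owns a relative attention bias embedding; the position bias it
        # computes is reused by every subsequent block (see the forward pass below).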
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        self.init_weights()

    def forward(self,
                hidden_states,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                head_mask=None):

        batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, seq_length).to(hidden_states.device)
        if self.is_decoder and encoder_attention_mask is None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(hidden_states.device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
        # Provided a padding mask of dimensions [batch_size, seq_length]
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                seq_ids = torch.arange(seq_length, device=hidden_states.device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                causal_mask = causal_mask.to(attention_mask)
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -1e9 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.

        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
        # extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))

        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
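        # Tiny worked example (illustrative): a padding mask of [1, 1, 0] becomes the additive
        # mask [0.0, 0.0, -1e9], so the padded position gets essentially zero attention weight.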

        if self.is_decoder:
            # If a 2D or 3D attention mask is provided for the cross-attention
            # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if encoder_attention_mask.dim() == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            if encoder_attention_mask.dim() == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
            # encoder_extended_attention_mask = (encoder_extended_attention_mask == encoder_extended_attention_mask.transpose(-1, -2))

            encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_layers

        all_hidden_states = ()
        all_attentions = ()
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(hidden_states)
        for i, layer_module in enumerate(self.block):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(hidden_states,
                                         attention_mask=extended_attention_mask,
                                         position_bias=position_bias,
                                         encoder_hidden_states=encoder_hidden_states,
                                         encoder_attention_mask=encoder_extended_attention_mask,
                                         encoder_decoder_position_bias=encoder_decoder_position_bias,
                                         head_mask=head_mask[i])
            # layer_outputs is a tuple with:
            # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
            hidden_states = layer_outputs[0]
            if i == 0:
                # We share the position biases between the layers - the first layer stores them
                # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                position_bias = layer_outputs[2 if self.output_attentions else 1]
                if self.is_decoder:
                    encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)  # We keep only self-attention weights for now

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)


T5_START_DOCSTRING = r"""    The T5 model was proposed in
    `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`_
    by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
    It's an encoder-decoder transformer pre-trained in a text-to-text denoising generative setting.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matter related to general usage and behavior.

    .. _`Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`:
        https://arxiv.org/abs/1910.10683

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model. 
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            Unlike BERT, T5 does not use ``[CLS]``/``[SEP]`` special tokens to format its inputs.

            T5 is a model with relative position embeddings so you should be able to pad the inputs on
            the right or the left.

            Indices can be obtained using :class:`transformers.T5Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""

@add_start_docstrings("The bare T5 Model transformer outputting raw hidden-states "
                      "without any specific head on top.",
                      T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
class T5Model(T5PreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = T5Tokenizer.from_pretrained('t5-small')
        model = T5Model.from_pretrained('t5-small')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids=input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        super(T5Model, self).__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        self.encoder = T5Stack(encoder_config)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        self.decoder = T5Stack(decoder_config)

        self.init_weights()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    def forward(self, **kwargs):
        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
        # that apply to the model as a whole.
        # We let the specific kwargs override the common ones in case of conflict.
        kwargs_common = dict((k, v) for k, v in kwargs.items()
                             if not k.startswith("encoder_") and not k.startswith("decoder_"))
        kwargs_encoder = kwargs_common.copy()
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder.update(dict((k[len("encoder_"):], v) for k, v in kwargs.items() if k.startswith("encoder_")))
        kwargs_decoder.update(dict((k[len("decoder_"):], v) for k, v in kwargs.items() if k.startswith("decoder_")))
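        # Illustrative call (hypothetical tensor names): model(encoder_input_ids=src_ids,
        # decoder_input_ids=tgt_ids, decoder_attention_mask=tgt_mask) routes input_ids=src_ids
        # to the encoder stack and input_ids=tgt_ids, attention_mask=tgt_mask to the decoder
        # stack, while un-prefixed keyword arguments are forwarded to both stacks.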

        # Encode if needed (training, first prediction pass)
        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
        encoder_attention_mask = kwargs_encoder.get("attention_mask", None)
        if encoder_hidden_states is None:
            # Convert encoder inputs into embeddings if needed
            hidden_states = kwargs_encoder.pop("inputs_embeds", None)
            if hidden_states is None:
                encoder_inputs_ids = kwargs_encoder.pop("input_ids")
                hidden_states = self.shared(encoder_inputs_ids)  # Convert inputs in embeddings

            if encoder_attention_mask is not None:
                # Apply masking
                encoder_attention_mask = (encoder_attention_mask != 0).to(hidden_states)
                hidden_states = hidden_states * encoder_attention_mask.unsqueeze(-1)

            encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
            encoder_hidden_states = encoder_outputs[0]
        else:
            encoder_outputs = ()

        # Decode
        # Convert decoder inputs into embeddings if needed
        hidden_states = kwargs_decoder.pop("inputs_embeds", None)
        if hidden_states is None:
            decoder_inputs_ids = kwargs_decoder.pop("input_ids")
            hidden_states = self.shared(decoder_inputs_ids)

        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
        kwargs_decoder["encoder_attention_mask"] = encoder_attention_mask
        decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)

        return decoder_outputs + encoder_outputs


@add_start_docstrings("""T5 Model with a `language modeling` head on top. """,
    T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
class T5WithLMHeadModel(T5PreTrainedModel):
    r"""
        **decoder_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``decoder_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Masked language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = T5Tokenizer.from_pretrained('t5-small')
        model = T5WithLMHeadModel.from_pretrained('t5-small')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids=input_ids, decoder_lm_labels=input_ids)
        loss, prediction_scores = outputs[:2]

    """
    def __init__(self, config):
        super(T5WithLMHeadModel, self).__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        self.encoder = T5Stack(encoder_config)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        self.decoder = T5Stack(decoder_config)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.init_weights()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, **kwargs):
        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
        # that apply to the model as a whole.
        # We let the specific kwargs override the common ones in case of conflict.

        lm_labels = kwargs.pop('decoder_lm_labels', None)

        kwargs_common = dict((k, v) for k, v in kwargs.items()
                             if not k.startswith("encoder_") and not k.startswith("decoder_"))
        kwargs_encoder = kwargs_common.copy()
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder.update(dict((k[len("encoder_"):], v) for k, v in kwargs.items() if k.startswith("encoder_")))
        kwargs_decoder.update(dict((k[len("decoder_"):], v) for k, v in kwargs.items() if k.startswith("decoder_")))

        # Encode if needed (training, first prediction pass)
        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
        if encoder_hidden_states is None:
            # Convert encoder inputs into embeddings if needed
            hidden_states = kwargs_encoder.pop("inputs_embeds", None)
            if hidden_states is None:
                encoder_inputs_ids = kwargs_encoder.pop("input_ids")
                hidden_states = self.shared(encoder_inputs_ids)  # Convert inputs in embeddings

            encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
            encoder_hidden_states = encoder_outputs[0]
        else:
            encoder_outputs = ()

        # Decode
        # Convert decoder inputs into embeddings if needed
        hidden_states = kwargs_decoder.pop("inputs_embeds", None)
        if hidden_states is None:
            decoder_inputs_ids = kwargs_decoder.pop("input_ids")
            hidden_states = self.shared(decoder_inputs_ids)

        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
        kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
        decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)

        sequence_output = decoder_outputs[0]
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)
        lm_logits = self.lm_head(sequence_output)

        decoder_outputs = (lm_logits,) + decoder_outputs[1:]  # Add hidden states and attention if they are here
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
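            # Note: logits at position i are scored against the label at position i + 1, so the
            # decoder input ids can be passed directly as `decoder_lm_labels` for standard
            # next-token prediction; label positions set to -1 are ignored by the loss.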
            decoder_outputs = (loss,) + decoder_outputs  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        return decoder_outputs + encoder_outputs