modeling_transfo_xl.py 38.8 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
# coding=utf-8
thomwolf's avatar
thomwolf committed
2
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
thomwolf's avatar
thomwolf committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model.
    Adapted from https://github.com/kimiyoung/transformer-xl.
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
"""

21
22
from __future__ import absolute_import, division, print_function, unicode_literals

thomwolf's avatar
thomwolf committed
23
24
25
26
27
import os
import json
import math
import logging
import collections
thomwolf's avatar
thomwolf committed
28
29
import sys
from io import open
thomwolf's avatar
thomwolf committed
30
31
32

import torch
import torch.nn as nn
33
import torch.nn.functional as F
thomwolf's avatar
thomwolf committed
34
35
36
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

37
38
from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
from .configuration_transfo_xl import TransfoXLConfig
thomwolf's avatar
thomwolf committed
39
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
40
from .file_utils import add_start_docstrings
thomwolf's avatar
thomwolf committed
41
42
43

logger = logging.getLogger(__name__)

# Shortcut name of each pretrained Transformer-XL checkpoint -> URL of its PyTorch weights file.
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'transfo-xl-wt103':
        "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
}
47

48
49
50
51
52
def build_tf_to_pytorch_map(model, config):
    """ Build a map from TensorFlow checkpoint variable names to PyTorch parameters.

        A map is used (rather than renaming PyTorch modules) so that the PyTorch model
        stays as close as possible to the original PyTorch implementation.

        Args:
            model: the PyTorch model. If it has a ``transformer`` attribute (LM-head model),
                the adaptive softmax in ``model.crit`` is mapped too and the remaining
                entries are built from ``model.transformer``.
            config: the model configuration. ``tie_weight`` and ``tie_projs`` are read for
                the adaptive softmax, ``untie_r`` for the relative positioning biases.

        Returns:
            dict: TF variable name -> PyTorch parameter. The two entries
            ``transformer/r_r_bias`` and ``transformer/r_w_bias`` map to *lists* of
            parameters: one per layer when ``config.untie_r`` is set, otherwise a
            one-element list holding the shared bias.

        Raises:
            NotImplementedError: if the model has an adaptive softmax and
                ``config.tie_weight`` is False.
    """
    tf_to_pt_map = {}

    if hasattr(model, 'transformer'):
        # We are loading in a TransfoXLLMHeadModel => we will load also the Adaptive Softmax
        tf_to_pt_map.update({
            "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
            "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias})
        for i, (out_l, proj_l, tie_proj) in enumerate(zip(
                                model.crit.out_layers,
                                model.crit.out_projs,
                                config.tie_projs)):
            layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
            if not config.tie_weight:
                # Untied output embeddings would need a per-cutoff 'lookup_table'
                # variable; this is presumably not implemented in the TF code.
                raise NotImplementedError(
                    "Loading TF weights is only supported with config.tie_weight=True")
            # With tied weights, only the output bias is read from the softmax scope.
            tf_to_pt_map[layer_str + 'b'] = out_l.bias
            if not tie_proj:
                tf_to_pt_map[layer_str + 'proj'] = proj_l
        # Now load the rest of the transformer
        model = model.transformer

    # Embeddings.
    # NOTE(review): zip() stops at the shorter list, so lookup tables without a
    # matching projection (an empty emb_projs) are not mapped here.
    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
        tf_to_pt_map.update({
            layer_str + 'lookup_table': embed_l.weight,
            layer_str + 'proj_W': proj_l
            })

    # Transformer blocks. pos_ff.CoreNet[0] and CoreNet[3] are the two nn.Linear
    # layers of PositionwiseFF.
    for i, b in enumerate(model.layers):
        layer_str = "transformer/layer_%d/" % i
        tf_to_pt_map.update({
            layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
            layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
            layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
            layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
            layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
            layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
            layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
            layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
            layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
            layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
            layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
        })

    # Relative positioning biases: always stored as lists (one entry per layer when
    # untied, a single shared entry otherwise).
    if config.untie_r:
        r_r_list = [b.dec_attn.r_r_bias for b in model.layers]
        r_w_list = [b.dec_attn.r_w_bias for b in model.layers]
    else:
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
    tf_to_pt_map.update({
        'transformer/r_r_bias': r_r_list,
        'transformer/r_w_bias': r_w_list})
    return tf_to_pt_map

def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """ Load TensorFlow checkpoint weights into a PyTorch Transformer-XL model.

        Args:
            model: the PyTorch model to fill; its parameters are replaced in place.
            config: the model configuration, forwarded to ``build_tf_to_pytorch_map``.
            tf_path: path of the TensorFlow checkpoint.

        Returns:
            ``model``.

        Raises:
            ImportError: if numpy or TensorFlow cannot be imported.
            AssertionError: if a mapped variable is missing from the checkpoint, or a
                checkpoint array does not match the shape of its PyTorch parameter.
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    tf_weights = {}
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        tf_weights[name] = array

    # Checks below raise explicitly (instead of using `assert`) so they still run
    # under `python -O`; a skipped shape check would silently assign a wrong-shaped
    # tensor through `.data`.
    for name, pointer in tf_to_pt_map.items():
        if name not in tf_weights:
            raise AssertionError("Variable {} not found in TF checkpoint {}".format(name, tf_path))
        array = tf_weights[name]
        # Kernels and projection matrices are transposed before being copied
        # (TF and PyTorch lay these matrices out differently).
        if 'kernel' in name or 'proj' in name:
            array = np.transpose(array)
        if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
            # Untied biases: the TF variable stacks one bias per layer along axis 0.
            if len(pointer) != array.shape[0]:
                raise AssertionError(len(pointer), array.shape[0])
            for i, p_i in enumerate(pointer):
                arr_i = array[i, ...]
                if p_i.shape != arr_i.shape:
                    raise AssertionError(p_i.shape, arr_i.shape)
                logger.info("Initialize PyTorch weight {} for layer {}".format(name, i))
                p_i.data = torch.from_numpy(arr_i)
        else:
            if isinstance(pointer, list):
                # build_tf_to_pytorch_map wraps relative biases in a list; a single
                # entry is either the shared bias (untie_r=False) or a one-layer
                # model, whose TF variable may carry a leading axis of size 1.
                pointer = pointer[0]
                if array.ndim == pointer.dim() + 1 and array.shape[0] == 1:
                    array = array[0]
            if pointer.shape != array.shape:
                raise AssertionError(pointer.shape, array.shape)
            logger.info("Initialize PyTorch weight {}".format(name))
            pointer.data = torch.from_numpy(array)
        # Drop the variable and its Adam optimizer slots (if any) so that the final
        # log only lists checkpoint variables that were not used.
        tf_weights.pop(name, None)
        tf_weights.pop(name + '/Adam', None)
        tf_weights.pop(name + '/Adam_1', None)

    logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
    return model


thomwolf's avatar
thomwolf committed
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
class PositionalEmbedding(nn.Module):
    """ Sinusoidal embeddings of (relative) positions.

        Args:
            demb: embedding size. The first half of every embedding holds sines and the
                second half cosines (``demb`` is presumably even — TODO confirm).
    """

    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()

        self.demb = demb

        # One inverse frequency 1 / 10000^(k / demb) per sin/cos pair, k = 0, 2, 4, ...
        even_dims = torch.arange(0.0, demb, 2.0)
        inv_freq = 1 / (10000 ** (even_dims / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        """ Embed the positions in ``pos_seq`` (1-D).

            Returns a tensor of shape ``len(pos_seq) x 1 x demb``, or, when ``bsz`` is
            given, that tensor expanded (not copied) to ``len(pos_seq) x bsz x demb``.
        """
        angles = torch.ger(pos_seq, self.inv_freq)   # outer product: positions x frequencies
        pos_emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:, None, :]
        if bsz is None:
            return pos_emb
        return pos_emb.expand(-1, bsz, -1)


thomwolf's avatar
thomwolf committed
195

thomwolf's avatar
thomwolf committed
196
class PositionwiseFF(nn.Module):
    """ Position-wise feed-forward sub-layer with residual connection and layer norm.

        Args:
            d_model: input and output size.
            d_inner: hidden size of the two-layer MLP.
            dropout: dropout probability, applied after the ReLU and after the output layer.
            pre_lnorm: if truthy, normalize the input before the MLP; otherwise normalize
                after the residual sum.
            layer_norm_epsilon: epsilon of the layer norm.
    """

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        # Layer order matters: build_tf_to_pytorch_map addresses CoreNet[0] and CoreNet[3].
        mlp_layers = [
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        ]
        self.CoreNet = nn.Sequential(*mlp_layers)

        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)

        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        """ Apply the feed-forward block to ``inp`` (last dim ``d_model``). """
        if not self.pre_lnorm:
            # Post-norm: MLP, add the residual, then layer norm.
            return self.layer_norm(inp + self.CoreNet(inp))
        # Pre-norm: layer norm, MLP, then add the residual.
        return self.CoreNet(self.layer_norm(inp)) + inp

thomwolf's avatar
thomwolf committed
231

232
class RelPartialLearnableMultiHeadAttn(nn.Module):
    """ Multi-head self-attention with Transformer-XL relative positional encodings.

        Scores are the sum of a content term (queries + ``r_w_bias`` against the keys)
        and a position term (queries + ``r_r_bias`` against the projected relative
        position embeddings, realigned by ``_rel_shift``), scaled by ``1/sqrt(d_head)``.

        Args:
            n_head, d_model, d_head: number of heads, model size, size per head.
            dropout: dropout on the attention output. dropatt: dropout on attention probabilities.
            tgt_len, ext_len, mem_len: accepted for interface compatibility; unused here.
            pre_lnorm: normalize the inputs (truthy) or the residual sum (falsy).
            r_r_bias, r_w_bias: bias parameters shared from outside; if either is None this
                module allocates its own ``(n_head, d_head)`` pair, left uninitialized here
                (``TransfoXLPreTrainedModel._init_weights`` initializes attributes with these names).
            output_attentions: also return the attention probabilities.
            layer_norm_epsilon: epsilon of the layer norm.
    """

    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 r_r_bias=None, r_w_bias=None, output_attentions=False,
                 layer_norm_epsilon=1e-5):
        super(RelPartialLearnableMultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        inner_dim = n_head * d_head
        # A single projection produces queries, keys and values.
        self.qkv_net = nn.Linear(d_model, 3 * inner_dim, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(inner_dim, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is not None and r_w_bias is not None:
            # Biases are shared with the caller.
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias
        else:
            # Biases are not shared: allocate this module's own pair.
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))

        self.r_net = nn.Linear(self.d_model, inner_dim, bias=False)

    def _rel_shift(self, x):
        """ Shift each row of ``x`` (dims 0 and 1 are the two time axes) so that
            relative-position scores line up with absolute key positions.

            Pads one zero column in front of dim 1, reinterprets the buffer with
            dims 0/1 swapped, drops the first row and reshapes back.
        """
        n_rows, n_cols = x.size(0), x.size(1)
        trailing = x.size()[2:]
        pad = torch.zeros((n_rows, 1) + trailing, device=x.device, dtype=x.dtype)
        padded = torch.cat([pad, x], dim=1).view(*((n_cols + 1, n_rows) + trailing))
        return padded[1:].view_as(x)

    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        """ Relative-position attention over ``w``, optionally preceded by ``mems``.

            Args:
                w: current segment, ``qlen x bsz x d_model``.
                r: relative position embeddings; ``rlen = r.size(0)`` and ``r_net(r)``
                    must be viewable as ``rlen x n_head x d_head``.
                attn_mask: optional mask; entries equal to 1 are masked out. A 2-D mask
                    is applied as ``mask[None, :, :, None]``, a 3-D one as
                    ``mask[:, :, :, None]``, against ``qlen x klen x bsz x n_head``
                    scores; masks of other ranks are ignored.
                mems: optional memory concatenated in front of ``w`` along dim 0.
                head_mask: optional multiplier for the attention probabilities.

            Returns:
                list: ``[output]`` of shape ``qlen x bsz x d_model``, followed by the
                attention probabilities (``qlen x klen x bsz x n_head``) when
                ``output_attentions`` is set.
        """
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        # Keys and values also cover the memory; queries only the current segment.
        source = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            source = self.layer_norm(source)
        w_heads = self.qkv_net(source)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        if mems is not None:
            w_head_q = w_head_q[-qlen:]
        klen = w_head_k.size(0)

        heads = (self.n_head, self.d_head)
        w_head_q = w_head_q.view(qlen, bsz, *heads)    # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, *heads)    # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, *heads)    # klen x bsz x n_head x d_head
        r_head_k = r_head_k.view(rlen, *heads)         # rlen x n_head x d_head

        #### attention scores: content term (AC) + position term (BD)
        AC = torch.einsum('ibnd,jbnd->ijbn', (w_head_q + self.r_w_bias, w_head_k))
        BD = self._rel_shift(torch.einsum('ibnd,jnd->ijbn', (w_head_q + self.r_r_bias, r_head_k)))

        attn_score = AC + BD                           # qlen x klen x bsz x n_head
        attn_score.mul_(self.scale)

        #### masking
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            # A large negative fill that stays representable in float16.
            fill_value = -65000 if next(self.parameters()).dtype == torch.float16 else -1e30
            if attn_mask.dim() == 2:
                score_mask = attn_mask[None, :, :, None]
            elif attn_mask.dim() == 3:
                score_mask = attn_mask[:, :, :, None]
            else:
                score_mask = None
            if score_mask is not None:
                attn_score = attn_score.float().masked_fill(
                    score_mask, fill_value).type_as(attn_score)

        #### attention probabilities (normalized over keys, dim 1)
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        #### weighted values, merged heads: qlen x bsz x (n_head * d_head)
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        attn_out = self.drop(self.o_net(attn_vec))

        #### residual connection (+ layer norm in post-norm mode)
        residual = w + attn_out
        outputs = [residual if self.pre_lnorm else self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
thomwolf's avatar
thomwolf committed
368
369
370


class RelPartialLearnableDecoderLayer(nn.Module):
    """ One Transformer-XL layer: relative-position self-attention followed by a
        position-wise feed-forward sub-layer.

        Extra keyword arguments are forwarded to ``RelPartialLearnableMultiHeadAttn``;
        ``pre_lnorm`` (if given) is also passed to ``PositionwiseFF``.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5,
                 **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelPartialLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout,
            layer_norm_epsilon=layer_norm_epsilon, **kwargs)
        self.pos_ff = PositionwiseFF(
            d_model, d_inner, dropout,
            pre_lnorm=kwargs.get('pre_lnorm'),
            layer_norm_epsilon=layer_norm_epsilon)

    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
        """ Run attention then feed-forward on ``dec_inp``.

            Returns a list: the layer output first, followed by whatever extra
            outputs the attention module produced (attention probabilities).
        """
        attn_results = self.dec_attn(dec_inp, r, attn_mask=dec_attn_mask,
                                     mems=mems, head_mask=head_mask)
        layer_out = self.pos_ff(attn_results[0])
        return [layer_out] + attn_results[1:]
thomwolf's avatar
thomwolf committed
391
392
393


class AdaptiveEmbedding(nn.Module):
    """ Adaptive input embeddings: the vocabulary is split into clusters at ``cutoffs``.

        With ``div_val == 1`` a single table of size ``d_embed`` is used, projected to
        ``d_proj`` only when the two sizes differ. Otherwise cluster ``i`` has its own
        table of size ``d_embed // div_val**i`` and its own projection to ``d_proj``.
        Outputs are multiplied by ``sqrt(d_proj)``.

        Args:
            n_token: vocabulary size (appended to ``cutoffs`` as the last boundary).
            d_embed: embedding size of the first cluster.
            d_proj: output embedding size.
            cutoffs: list of token-id cluster boundaries.
            div_val: divisor of the embedding size between consecutive clusters.
            sample_softmax: if > 0, the single table (``div_val == 1``) uses sparse gradients.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        # Cluster i covers token ids [cutoff_ends[i], cutoff_ends[i + 1]).
        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val != 1:
            bounds = zip(self.cutoff_ends[:-1], self.cutoff_ends[1:])
            for i, (start, end) in enumerate(bounds):
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(end - start, d_emb_i))
                # Left uninitialized here; filled in by the model's weight init.
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))
        else:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))

    def forward(self, inp):
        """ Embed the token ids ``inp`` (any shape); returns ``inp.shape + (d_proj,)``. """
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
            embed.mul_(self.emb_scale)
            return embed

        param = next(self.parameters())
        flat_ids = inp.view(-1)
        flat_emb = torch.zeros([flat_ids.size(0), self.d_proj],
                               dtype=param.dtype, device=param.device)
        bounds = zip(self.cutoff_ends[:-1], self.cutoff_ends[1:])
        for (start, end), table, proj in zip(bounds, self.emb_layers, self.emb_projs):
            in_cluster = (flat_ids >= start) & (flat_ids < end)
            positions = in_cluster.nonzero().squeeze()
            if positions.numel() == 0:
                continue
            # Embed with cluster-local ids, project to d_proj, scatter back in place.
            local_ids = flat_ids.index_select(0, positions) - start
            flat_emb.index_copy_(0, positions, F.linear(table(local_ids), proj))

        embed = flat_emb.view(inp.size() + (self.d_proj,))
        embed.mul_(self.emb_scale)
        return embed


457
class TransfoXLPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = TransfoXLConfig
    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_transfo_xl
    base_model_prefix = "transformer"

    def _init_weight(self, weight):
        """ Initialize ``weight`` in place per ``config.init``: 'uniform' in
            ``[-init_range, init_range]`` or 'normal' with std ``init_std``.
            Any other value leaves the tensor untouched. """
        scheme = self.config.init
        if scheme == 'uniform':
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
        elif scheme == 'normal':
            nn.init.normal_(weight, 0.0, self.config.init_std)

    def _init_bias(self, bias):
        """ Set ``bias`` to zero in place. """
        nn.init.constant_(bias, 0.0)

    def _init_weights(self, m):
        """ Initialize the weights of submodule ``m``.

            Dispatches on substrings of the class name; the order of the checks
            matters (e.g. 'AdaptiveEmbedding' must be tested before 'Embedding').
        """
        classname = m.__class__.__name__
        if 'Linear' in classname:
            if getattr(m, 'weight', None) is not None:
                self._init_weight(m.weight)
            if getattr(m, 'bias', None) is not None:
                self._init_bias(m.bias)
        elif 'AdaptiveEmbedding' in classname:
            if hasattr(m, 'emb_projs'):
                for proj in m.emb_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
        elif 'Embedding' in classname:
            if hasattr(m, 'weight'):
                self._init_weight(m.weight)
        elif 'ProjectedAdaptiveLogSoftmax' in classname:
            if getattr(m, 'cluster_weight', None) is not None:
                self._init_weight(m.cluster_weight)
            if getattr(m, 'cluster_bias', None) is not None:
                self._init_bias(m.cluster_bias)
            if hasattr(m, 'out_projs'):
                for proj in m.out_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
        elif 'LayerNorm' in classname:
            if hasattr(m, 'weight'):
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if getattr(m, 'bias', None) is not None:
                self._init_bias(m.bias)
        else:
            # Modules holding relative-position parameters directly (e.g. the
            # attention layers, or the model itself when biases are shared).
            for attr_name in ('r_emb', 'r_w_bias', 'r_r_bias'):
                if hasattr(m, attr_name):
                    self._init_weight(getattr(m, attr_name))
            if hasattr(m, 'r_bias'):
                self._init_bias(m.r_bias)
thomwolf's avatar
thomwolf committed
515

516

thomwolf's avatar
thomwolf committed
517
518
519
520
521
522
# Model description passed to `add_start_docstrings` for the Transformer-XL model classes.
TRANSFO_XL_START_DOCSTRING = r"""    The Transformer-XL model was proposed in
    `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
    by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
    It's a causal (uni-directional) transformer with relative positioning (sinusoidal) embeddings which can reuse
    previously computed hidden-states to attend to longer context (memory).
    This model also uses adaptive softmax inputs and outputs (tied).

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.

    .. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`:
        https://arxiv.org/abs/1901.02860

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
thomwolf's avatar
thomwolf committed
538

thomwolf's avatar
thomwolf committed
539
540
541
542
# Description of the forward inputs passed to `add_start_docstrings` for the model classes.
TRANSFO_XL_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            Transformer-XL is a model with relative position embeddings so you can either pad the inputs on
            the right or on the left.
            Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`.
            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **mems**: (`optional`)
            list of ``torch.FloatTensor`` (one for each layer):
            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` output below). Can be used to speed up sequential decoding and attend to longer context.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
557

Julien Chaumond's avatar
Julien Chaumond committed
558
@add_start_docstrings("The bare Transformer-XL Model transformer outputting raw hidden-states without any specific head on top.",
                      TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TransfoXLModel(TransfoXLPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the last layer of the model.
        **mems**:
            list of ``torch.FloatTensor`` (one for each layer):
            that contains pre-computed hidden-states (keys and values in the attention blocks) as computed by the model
            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
            Note: unlike the other outputs, these tensors are kept in the model's internal
            ``(memory_length, batch_size, hidden_size)`` layout; pass them back unchanged.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states, mems = outputs[:2]

    """
    def __init__(self, config):
        super(TransfoXLModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.n_token = config.n_token

        self.d_embed = config.d_embed
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_head = config.d_head

        self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
                                          div_val=config.div_val)

        self.drop = nn.Dropout(config.dropout)

        self.n_layer = config.n_layer

        self.tgt_len = config.tgt_len
        self.mem_len = config.mem_len
        self.ext_len = config.ext_len
        self.max_klen = config.tgt_len + config.ext_len + config.mem_len

        self.attn_type = config.attn_type

        if not config.untie_r:
            # Relative-attention biases shared by every layer. With ``untie_r`` the layers receive
            # ``None`` instead (presumably creating per-layer biases themselves).
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))

        self.layers = nn.ModuleList()
        if config.attn_type == 0:  # the default attention
            for _ in range(config.n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
                        r_r_bias=None if config.untie_r else self.r_r_bias,
                        output_attentions=self.output_attentions,
                        layer_norm_epsilon=config.layer_norm_epsilon)
                )
        else:  # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
            raise NotImplementedError  # Removed them to avoid maintaining dead code

        self.same_length = config.same_length
        self.clamp_len = config.clamp_len

        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        """Return the input embedding module unchanged.

        Resizing is not performed for the adaptive embedding; ``new_num_tokens`` is ignored.
        """
        return self.word_emb

    def backward_compatible(self):
        """Disable sampled softmax on this module (sets ``sample_softmax`` to ``-1``)."""
        self.sample_softmax = -1

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Change the target, extended-context and memory lengths used by later forward passes."""
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def _prune_heads(self, heads):
        """Head pruning is not supported for Transformer-XL; only logs a message."""
        logger.info("Head pruning is not implemented for Transformer-XL model")

    def init_mems(self, data):
        """Create zero-filled memories, one per layer.

        Args:
            data: input ids in the internal ``[len, bsz]`` layout; only ``data.size(1)`` (batch size) is read.

        Returns:
            A list of ``n_layer`` tensors of shape ``(mem_len, bsz, d_model)`` matching the dtype/device of
            the model parameters, or ``None`` when ``mem_len`` is not positive (memory disabled).
        """
        if self.mem_len > 0:
            param = next(self.parameters())
            return [torch.zeros(self.mem_len, data.size(1), self.config.d_model,
                                dtype=param.dtype, device=param.device)
                    for _ in range(self.n_layer)]
        return None

    def _update_mems(self, hids, mems, qlen, mlen):
        """Build the memories for the next segment.

        Args:
            hids: list with one tensor per layer holding that layer's input for the current segment
                (``[qlen, bsz, ...]`` layout).
            mems: list with one memory tensor per layer (``[mlen, bsz, ...]``), or ``None``.
            qlen: number of tokens in the current segment.
            mlen: number of memory steps that were prepended to the current segment.

        Returns:
            ``None`` when ``mems`` is ``None``; otherwise a list of detached tensors containing at most
            ``mem_len`` steps per layer.
        """
        # Memory is disabled: nothing to carry over.
        if mems is None:
            return None

        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            end_idx = mlen + max(0, qlen - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            new_mems = [torch.cat([mem, hid], dim=0)[beg_idx:end_idx].detach()
                        for mem, hid in zip(mems, hids)]

        return new_mems

    def forward(self, input_ids, mems=None, head_mask=None):
        """Run the Transformer-XL stack on ``input_ids`` of shape ``(batch_size, sequence_length)``.

        Returns a list ``[last_hidden_state, new_mems, (hidden_states), (attentions)]`` as described in
        the class docstring.
        """
        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        input_ids = input_ids.transpose(0, 1).contiguous()

        if mems is None:
            mems = self.init_mems(input_ids)

        qlen, _ = input_ids.size()

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and is reshaped to [num_hidden_layers x 1 x 1 x 1 x n_head] so it broadcasts against
        # per-layer [qlen x klen x bsz x n_head] attention probabilities
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.n_layer

        word_emb = self.word_emb(input_ids)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        # Causal mask over [qlen, klen]: position i may attend to all memory plus tokens <= i.
        if self.same_length:
            all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            # Additionally mask the oldest keys so every query sees the same number of positions.
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
                             + torch.tril(all_ones, -mask_shift_len))[:, :, None]  # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[:, :, None]

        hids = []
        attentions = []
        if self.attn_type == 0:  # default
            # Relative distances klen-1, ..., 0 (one per key position).
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            for i, layer in enumerate(self.layers):
                # Record each layer's input: it is what gets cached into the memories.
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        core_out = self.drop(core_out)

        # Arguments passed in the order declared by _update_mems (qlen, then mlen); they were
        # previously swapped, which shifted the cached window whenever ext_len > 0.
        new_mems = self._update_mems(hids, mems, qlen, mlen)

        # We transpose back here to shape [bsz, len, hidden_dim]
        outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
        if self.output_hidden_states:
            # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
            hids.append(core_out)
            hids = [t.transpose(0, 1).contiguous() for t in hids]
            outputs.append(hids)
        if self.output_attentions:
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
            attentions = [t.permute(2, 3, 0, 1).contiguous() for t in attentions]
            outputs.append(attentions)

        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
thomwolf's avatar
thomwolf committed
775
776


thomwolf's avatar
thomwolf committed
777
778
779
@add_start_docstrings("""The Transformer-XL Model with a language modeling head on top
    (adaptive softmax with weights tied to the adaptive input embeddings)""",
    TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-1, 0, ..., config.vocab_size]``
            All labels set to ``-1`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Language modeling loss for each position (not reduced).
        **prediction_scores**: ``None`` if ``labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            We don't output them when the loss is computed to speedup adaptive softmax decoding.
        **mems**:
            list of ``torch.FloatTensor`` (one for each layer):
            that contains pre-computed hidden-states (keys and values in the attention blocks) as computed by the model
            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, mems = outputs[:2]

    """
    def __init__(self, config):
        super(TransfoXLLMHeadModel, self).__init__(config)
        self.transformer = TransfoXLModel(config)
        self.sample_softmax = config.sample_softmax
        # use sampled softmax
        if config.sample_softmax > 0:
            # LogUniformSampler is not in this module's visible imports (only ProjectedAdaptiveLogSoftmax
            # and sample_logits are), which made this branch raise NameError. It is expected to live in
            # modeling_transfo_xl_utilities next to sample_logits — confirm if that module changes.
            from .modeling_transfo_xl_utilities import LogUniformSampler
            self.out_layer = nn.Linear(config.d_model, config.n_token)
            self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                    config.cutoffs, div_val=config.div_val)
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """
        Run this to be sure output and input (adaptive) softmax weights are tied
        """
        # sampled softmax
        if self.sample_softmax > 0:
            if self.config.tie_weight:
                self.out_layer.weight = self.transformer.word_emb.weight
        # adaptive softmax (including standard softmax)
        else:
            if self.config.tie_weight:
                emb_layers = self.transformer.word_emb.emb_layers
                for i, out_layer in enumerate(self.crit.out_layers):
                    self._tie_or_clone_weights(out_layer, emb_layers[i])
            if self.config.tie_projs:
                emb_projs = self.transformer.word_emb.emb_projs
                for i, tie_proj in enumerate(self.config.tie_projs):
                    # With div_val == 1 all clusters share the first input projection;
                    # otherwise cluster i is tied to input projection i.
                    if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
                        src = emb_projs[0]
                    elif tie_proj and self.config.div_val != 1:
                        src = emb_projs[i]
                    else:
                        continue
                    if self.config.torchscript:
                        # TorchScript cannot handle shared parameters, so clone instead of tying.
                        self.crit.out_projs[i] = nn.Parameter(src.clone())
                    else:
                        self.crit.out_projs[i] = src

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Forward the new target/extended/memory lengths to the underlying transformer."""
        self.transformer.reset_length(tgt_len, ext_len, mem_len)

    def init_mems(self, data):
        """Create initial (zero) memories via the underlying transformer."""
        return self.transformer.init_mems(data)

    def forward(self, input_ids, mems=None, head_mask=None, labels=None):
        """Run the transformer and the LM head; see the class docstring for the returned list.

        Returns ``[loss, None, new_mems, ...]`` when ``labels`` is given (adaptive softmax path), and
        ``[prediction_scores, new_mems, ...]`` otherwise.
        """
        bsz = input_ids.size(0)
        tgt_len = input_ids.size(1)

        transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask)

        last_hidden = transformer_outputs[0]
        pred_hid = last_hidden[:, -tgt_len:]
        outputs = transformer_outputs[1:]
        if self.sample_softmax > 0 and self.training:
            assert self.config.tie_weight
            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
            softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
            outputs = [softmax_output] + outputs
            if labels is not None:
                # TODO: This is not implemented
                raise NotImplementedError
        else:
            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
            if labels is None:
                softmax_output = softmax_output.view(bsz, tgt_len, -1)
                outputs = [softmax_output] + outputs
            else:
                # NOTE(review): the loss is reshaped to one value per input position and this method does
                # not shift `labels` itself; confirm self.crit handles the shift before relying on
                # `labels = input_ids` as the class docstring suggests.
                softmax_output = softmax_output.view(bsz, tgt_len)
                outputs = [softmax_output, None] + outputs

        return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)