# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model.
    Adapted from https://github.com/kimiyoung/transformer-xl.
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
"""

21
22
from __future__ import absolute_import, division, print_function, unicode_literals

thomwolf's avatar
thomwolf committed
23
24
25
26
27
import os
import json
import math
import logging
import collections
thomwolf's avatar
thomwolf committed
28
29
import sys
from io import open
thomwolf's avatar
thomwolf committed
30
31
32

import torch
import torch.nn as nn
33
import torch.nn.functional as F
thomwolf's avatar
thomwolf committed
34
35
36
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

thomwolf's avatar
thomwolf committed
37
from .modeling_bert import BertLayerNorm as LayerNorm
thomwolf's avatar
thomwolf committed
38
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
39
from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
thomwolf's avatar
thomwolf committed
40
41
42

logger = logging.getLogger(__name__)

# All pretrained Transformer-XL artefacts live under the same S3 folder.
_TRANSFO_XL_S3_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert/"

# Shortcut name -> URL of the pretrained PyTorch weights.
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'transfo-xl-wt103': _TRANSFO_XL_S3_PREFIX + "transfo-xl-wt103-pytorch_model.bin",
}

# Shortcut name -> URL of the matching JSON configuration.
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'transfo-xl-wt103': _TRANSFO_XL_S3_PREFIX + "transfo-xl-wt103-config.json",
}
49

50
51
52
53
54
def build_tf_to_pytorch_map(model, config):
    """ A map of modules from TF to PyTorch.
        This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.

        Args:
            model: either a model exposing a `transformer` attribute (an LM head
                model, whose adaptive-softmax weights in `model.crit` are mapped
                too) or the bare transformer itself.
            config: configuration providing `tie_weight`, `tie_projs` and `untie_r`.

        Returns:
            dict mapping TF variable names to PyTorch parameters. The two
            relative-position bias entries ('transformer/r_r_bias' and
            'transformer/r_w_bias') map to a *list*: one parameter per layer when
            `config.untie_r` is set, otherwise a one-element list holding the
            model-level shared parameter.

        Raises:
            NotImplementedError: if an LM head model is given while
                `config.tie_weight` is False (untied softmax weights are not
                supported by this conversion).
    """
    tf_to_pt_map = {}

    if hasattr(model, 'transformer'):
        # We are loading in a TransfoXLLMHeadModel => we will load also the Adaptive Softmax
        tf_to_pt_map.update({
            "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
            "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias})
        for i, (out_l, proj_l, tie_proj) in enumerate(zip(model.crit.out_layers,
                                                          model.crit.out_projs,
                                                          config.tie_projs)):
            layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
            if not config.tie_weight:
                # Untied softmax weights would also need a 'lookup_table' entry;
                # per the original author this is probably not implemented in the TF code.
                raise NotImplementedError(
                    "Loading TF weights with untied softmax weights (tie_weight=False) is not supported.")
            # With tied weights only the output bias comes from the softmax scope.
            tf_to_pt_map[layer_str + 'b'] = out_l.bias
            if not tie_proj:
                tf_to_pt_map[layer_str + 'proj'] = proj_l
        # Now load the rest of the transformer
        model = model.transformer

    # Embeddings
    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
        tf_to_pt_map.update({
            layer_str + 'lookup_table': embed_l.weight,
            layer_str + 'proj_W': proj_l
            })

    # Transformer blocks
    for i, b in enumerate(model.layers):
        layer_str = "transformer/layer_%d/" % i
        tf_to_pt_map.update({
            layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
            layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
            layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
            layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
            layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
            layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
            layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
            # CoreNet[0] and CoreNet[3] are the two Linear layers of the feed-forward.
            layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
            layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
            layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
            layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
        })

    # Relative positioning biases: per-layer when untied, otherwise the shared model-level ones.
    if config.untie_r:
        r_r_list = [b.dec_attn.r_r_bias for b in model.layers]
        r_w_list = [b.dec_attn.r_w_bias for b in model.layers]
    else:
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
    tf_to_pt_map.update({
        'transformer/r_r_bias': r_r_list,
        'transformer/r_w_bias': r_w_list})
    return tf_to_pt_map

def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """ Load tf checkpoints in a pytorch model

        Args:
            model: the PyTorch model whose parameters are overwritten (the set of
                mapped parameters comes from `build_tf_to_pytorch_map`).
            config: model configuration; `config.untie_r` tells whether the
                relative-position biases are stored per layer (stacked on the
                first axis) or as a single shared tensor.
            tf_path: path of the TensorFlow checkpoint to read.

        Returns:
            `model`, with the mapped parameters replaced in place.

        Raises:
            ImportError: if numpy or TensorFlow cannot be imported.
            AssertionError: if a mapped variable is missing from the checkpoint
                or a shape does not match (both shapes are appended to the error).
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    tf_weights = {}
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        tf_weights[name] = array

    def _copy(param, arr, description):
        # Copy one numpy array into one parameter after checking shapes.
        try:
            assert param.shape == arr.shape
        except AssertionError as e:
            e.args += (param.shape, arr.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(description))
        param.data = torch.from_numpy(arr)

    for name, pointer in tf_to_pt_map.items():
        assert name in tf_weights, "Variable {} not found in the TF checkpoint".format(name)
        array = tf_weights[name]
        # Dense kernels and projections are transposed before being copied into PyTorch.
        if 'kernel' in name or 'proj' in name:
            array = np.transpose(array)
        if isinstance(pointer, list) and config.untie_r:
            # One bias per layer: the TF variable stacks them along its first axis.
            assert len(pointer) == array.shape[0]
            for i, p_i in enumerate(pointer):
                _copy(p_i, array[i, ...], "{} for layer {}".format(name, i))
        else:
            if isinstance(pointer, list):
                # Shared biases (untie_r=False): the map wraps the single shared
                # parameter in a list; unwrap it so the whole tensor is copied.
                pointer = pointer[0]
            _copy(pointer, array, name)
        # Drop the variable and its Adam slots so the final log only lists leftovers.
        tf_weights.pop(name, None)
        tf_weights.pop(name + '/Adam', None)
        tf_weights.pop(name + '/Adam_1', None)

    logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
    return model


178
class TransfoXLConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `TransfoXLModel`.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file.
            cutoffs: cutoffs for the adaptive softmax
            d_model: Dimensionality of the model's hidden states.
            d_embed: Dimensionality of the embeddings
            n_head: Number of attention heads for each attention layer in
                the Transformer encoder.
            d_head: Dimensionality of the model's heads.
            d_inner: Inner dimension in FF
            div_val: division factor for adaptive input and softmax
            pre_lnorm: apply LayerNorm to the input instead of the output
            n_layer: Number of hidden layers in the Transformer encoder.
            tgt_len: number of tokens to predict
            ext_len: length of the extended context
            mem_len: length of the retained previous heads
            clamp_len: use the same pos embeddings after clamp_len
            same_length: use the same attn length for all tokens
            proj_share_all_but_first: True to share all but first projs, False not to share.
                Stored as the list `tie_projs` (first entry always False).
            attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
            sample_softmax: number of samples in sampled softmax
            adaptive: use adaptive softmax
            tie_weight: tie the word embedding and softmax weights
            dropout: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            dropatt: The dropout ratio for the attention probabilities.
            untie_r: untie relative position biases
            init: parameter initializer to use
            init_range: parameters initialized by U(-init_range, init_range).
            proj_init_std: projection parameters initialized by N(0, proj_init_std)
            init_std: parameters initialized by N(0, init_std)
            **kwargs: forwarded to `PretrainedConfig.__init__`.
    """
    pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=267735,
                 cutoffs=(20000, 40000, 200000),
                 d_model=1024,
                 d_embed=1024,
                 n_head=16,
                 d_head=64,
                 d_inner=4096,
                 div_val=4,
                 pre_lnorm=False,
                 n_layer=18,
                 tgt_len=128,
                 ext_len=0,
                 mem_len=1600,
                 clamp_len=1000,
                 same_length=True,
                 proj_share_all_but_first=True,
                 attn_type=0,
                 sample_softmax=-1,
                 adaptive=True,
                 tie_weight=True,
                 dropout=0.1,
                 dropatt=0.0,
                 untie_r=True,
                 init="normal",
                 init_range=0.01,
                 proj_init_std=0.01,
                 init_std=0.02,
                 **kwargs):
        """Constructs TransfoXLConfig.

        Raises:
            ValueError: if `vocab_size_or_config_json_file` is neither an int
                (vocabulary size) nor a string (path to a JSON config file).
        """
        super(TransfoXLConfig, self).__init__(**kwargs)

        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            # Every key of the JSON file becomes an attribute, as-is.
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.n_token = vocab_size_or_config_json_file
            # Fresh list so neither the caller's sequence nor the default is shared.
            self.cutoffs = list(cutoffs)
            self.tie_weight = tie_weight
            # One flag per softmax cluster: the head projection (index 0) is never
            # tied; the tail clusters follow proj_share_all_but_first.
            self.tie_projs = [False] + [bool(proj_share_all_but_first)] * len(self.cutoffs)
            self.d_model = d_model
            self.d_embed = d_embed
            self.d_head = d_head
            self.d_inner = d_inner
            self.div_val = div_val
            self.pre_lnorm = pre_lnorm
            self.n_layer = n_layer
            self.n_head = n_head
            self.tgt_len = tgt_len
            self.ext_len = ext_len
            self.mem_len = mem_len
            self.same_length = same_length
            self.attn_type = attn_type
            self.clamp_len = clamp_len
            self.sample_softmax = sample_softmax
            self.adaptive = adaptive
            self.dropout = dropout
            self.dropatt = dropatt
            self.untie_r = untie_r
            self.init = init
            self.init_range = init_range
            self.proj_init_std = proj_init_std
            self.init_std = init_std
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @property
    def hidden_size(self):
        return self.d_model

    @property
    def num_attention_heads(self):
        return self.n_head

    @property
    def num_hidden_layers(self):
        return self.n_layer
thomwolf's avatar
thomwolf committed
301

thomwolf's avatar
thomwolf committed
302

thomwolf's avatar
thomwolf committed
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
class PositionalEmbedding(nn.Module):
    """Sinusoidal position embedding.

    Each position p is mapped to [sin(p * f_0..f_k), cos(p * f_0..f_k)] with
    inverse frequencies f_i = 1 / 10000 ** (2i / demb).
    """
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()

        self.demb = demb

        exponents = torch.arange(0.0, demb, 2.0) / demb
        # A buffer, not a parameter: it moves with the module but is not trained.
        self.register_buffer('inv_freq', 1 / (10000 ** exponents))

    def forward(self, pos_seq, bsz=None):
        """Return [len(pos_seq) x 1 x demb], or [len(pos_seq) x bsz x demb] when bsz is given."""
        angles = torch.ger(pos_seq, self.inv_freq)
        table = torch.cat([angles.sin(), angles.cos()], dim=-1)[:, None, :]
        if bsz is None:
            return table
        return table.expand(-1, bsz, -1)


thomwolf's avatar
thomwolf committed
322

thomwolf's avatar
thomwolf committed
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
class PositionwiseFF(nn.Module):
    """Position-wise feed-forward sub-layer with residual connection and LayerNorm.

    Args:
        d_model: input/output feature size.
        d_inner: hidden feature size of the two-layer MLP.
        dropout: dropout probability used after the ReLU and after the output layer.
        pre_lnorm: if truthy, normalise the input before the MLP and add the raw
            input back; otherwise normalise the residual sum.
    """
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model, self.d_inner, self.dropout = d_model, d_inner, dropout

        # Keep this index layout: the Linear layers sit at CoreNet[0] and CoreNet[3].
        mlp = [
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        ]
        self.CoreNet = nn.Sequential(*mlp)

        self.layer_norm = LayerNorm(d_model)

        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if not self.pre_lnorm:
            # Post-norm: MLP, residual add, then LayerNorm on the sum.
            return self.layer_norm(inp + self.CoreNet(inp))
        # Pre-norm: LayerNorm, MLP, then add the unnormalised input.
        return self.CoreNet(self.layer_norm(inp)) + inp

thomwolf's avatar
thomwolf committed
358
359


thomwolf's avatar
thomwolf committed
360
361
class MultiHeadAttn(nn.Module):
    """Multi-head attention (no relative positions) over an optional memory.

    Queries are projected from the current segment `h`; keys and values from
    `mems` (when given) concatenated in front of `h` along dim 0. `forward`
    returns a list: the attended output, plus the attention probabilities when
    `output_attentions` is set.
    """
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(MultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        # Queries from the segment; keys and values (fused) from the context.
        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        biases_shared = r_r_bias is not None and r_w_bias is not None
        if biases_shared:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias
        else:
            # Own biases, allocated uninitialised; this module's forward never reads them.
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))

    def forward(self, h, attn_mask=None, mems=None, head_mask=None):
        qlen, bsz = h.size(0), h.size(1)

        context = h if mems is None else torch.cat([mems, h], 0)
        if self.pre_lnorm:
            # NOTE(review): only the key/value context is normalised; the queries
            # below are projected from the raw `h`.
            context = self.layer_norm(context)
        klen = context.size(0)

        q_shape = (qlen, bsz, self.n_head, self.d_head)
        kv_shape = (klen, context.size(1), self.n_head, self.d_head)
        queries = self.q_net(h).view(*q_shape)
        keys, values = torch.chunk(self.kv_net(context), 2, -1)
        keys = keys.view(*kv_shape)
        values = values.view(*kv_shape)

        # [qlen x klen x bsz x n_head]
        scores = torch.einsum('ibnd,jbnd->ijbn', (queries, keys)) * self.scale
        if attn_mask is not None and attn_mask.any().item():
            # NOTE(review): a 3-D mask is aligned with score axes 0-2; a 2-D mask is
            # broadcast as [1 x d0 x d1 x 1], i.e. onto axes 1-2 — confirm the intended layout.
            if attn_mask.dim() == 2:
                scores.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif attn_mask.dim() == 3:
                scores.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # Normalise over keys (dim 1).
        attn_prob = self.dropatt(F.softmax(scores, dim=1))

        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # [qlen x klen x bsz x n_head] x [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, values))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        attn_out = self.drop(self.o_net(attn_vec))
        residual = h + attn_out
        outputs = [residual] if self.pre_lnorm else [self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)
        return outputs
thomwolf's avatar
thomwolf committed
448
449
450

class RelMultiHeadAttn(nn.Module):
    """Base class for the relative-position attention variants.

    Holds the fused query/key/value projection, output projection, dropouts,
    LayerNorm and the two relative-position biases, plus helpers that realign
    relative-position scores. Subclasses implement `forward`.
    """
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(RelMultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        # One projection producing queries, keys and values side by side.
        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        # tgt_len / ext_len / mem_len are accepted but not used by this class.
        if r_r_bias is None or r_w_bias is None: # Biases are not shared
            # Allocated uninitialised (torch.Tensor); presumably filled by the
            # model's weight initialisation — confirm.
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        else:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

    def _parallelogram_mask(self, h, w, left=False):
        """Return an [h x w] byte mask: ones with the corners outside a parallelogram zeroed.

        The leading min(h, w) square keeps its upper triangle and the trailing
        one its lower triangle; the rows are flipped unless `left` is True.
        """
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m,:m] = torch.triu(mask[:m,:m])
        mask[-m:,-m:] = torch.tril(mask[-m:,-m:])

        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        """Select a [qlen x klen x ...] window from x (zero-padded by qlen-1 columns) using `mask`."""
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
                                    device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        x = x_padded.masked_select(mask[:,:,None,None]) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x

    def _rel_shift(self, x, zero_triu=False):
        """Realign relative-position scores x of shape [qlen x klen x ...].

        A zero column is prepended, the buffer is reinterpreted as
        [(klen + 1) x qlen x ...], the first row is dropped and the result is
        viewed back to x's shape, which shifts each query row by an amount that
        depends on its index. With `zero_triu`, entries above diagonal offset
        klen - qlen are zeroed.
        """
        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
        x_padded = x_padded.view(*x_padded_shape)

        x = x_padded[1:].view_as(x)

        if zero_triu:
            # Build the mask on x's device and dtype: a default CPU/float32 mask
            # fails for CUDA inputs and silently promotes half-precision inputs.
            ones = torch.ones((x.size(0), x.size(1)), device=x.device, dtype=x.dtype)
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]

        return x

    def forward(self, w, r, attn_mask=None, mems=None):
        raise NotImplementedError

class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    """Transformer-XL relative attention with global content and position biases.

    The score is the sum of
    - a content term: queries shifted by `r_w_bias`, against keys;
    - a position term: queries shifted by `r_r_bias`, against the projected
      relative embeddings `r`, realigned with `_rel_shift`.
    `forward` returns a list: the output, plus the attention probabilities
    when `output_attentions` is set.
    """
    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

        # Projects relative position embeddings into the per-head key space.
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
        n_head, d_head = self.n_head, self.d_head

        # Keys and values cover [mems; w]; only the last qlen rows issue queries.
        context = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            context = self.layer_norm(context)
        w_heads = self.qkv_net(context)
        r_head_k = self.r_net(r)

        q_part, k_part, v_part = torch.chunk(w_heads, 3, dim=-1)
        q_part = q_part[-qlen:]
        klen = k_part.size(0)

        q_part = q_part.view(qlen, bsz, n_head, d_head)       # qlen x bsz x n_head x d_head
        k_part = k_part.view(klen, bsz, n_head, d_head)       # klen x bsz x n_head x d_head
        v_part = v_part.view(klen, bsz, n_head, d_head)       # klen x bsz x n_head x d_head
        r_head_k = r_head_k.view(rlen, n_head, d_head)        # rlen x n_head x d_head

        # Content term and (shifted) position term, each [qlen x klen x bsz x n_head].
        content_score = torch.einsum('ibnd,jbnd->ijbn', (q_part + self.r_w_bias, k_part))
        position_score = self._rel_shift(
            torch.einsum('ibnd,jnd->ijbn', (q_part + self.r_r_bias, r_head_k)))
        attn_score = (content_score + position_score) * self.scale

        if attn_mask is not None and attn_mask.any().item():
            # Mask in float32 with a large finite negative value, then cast back.
            # NOTE(review): a 2-D mask is broadcast as [1 x d0 x d1 x 1] — confirm intended layout.
            if attn_mask.dim() == 2:
                broadcast_mask = attn_mask[None, :, :, None]
            elif attn_mask.dim() == 3:
                broadcast_mask = attn_mask[:, :, :, None]
            else:
                broadcast_mask = None
            if broadcast_mask is not None:
                attn_score = attn_score.float().masked_fill(
                    broadcast_mask, -1e30).type_as(attn_score)

        # Normalise over keys (dim 1).
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # [qlen x bsz x n_head x d_head] -> [qlen x bsz x n_head * d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, v_part))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), n_head * d_head)

        attn_out = self.drop(self.o_net(attn_vec))
        residual = w + attn_out
        outputs = [residual if self.pre_lnorm else self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)
        return outputs
thomwolf's avatar
thomwolf committed
615
616
617
618
619

class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    """Relative attention whose relative-position terms are passed to `forward`.

    forward inputs besides `w`:
      r_emb:    [klen, n_head, d_head], used for term B
      r_w_bias: [n_head, d_head], used for term C
      r_bias:   [klen, n_head], used for term D
    r_emb and r_bias are left-padded (by repeating their first row) or
    truncated to their last klen rows. Returns a list: the output, plus the
    attention probabilities when `output_attentions` is set.
    """
    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
        qlen, bsz = w.size(0), w.size(1)
        n_head, d_head = self.n_head, self.d_head

        # Keys and values cover [mems; w]; only the last qlen rows issue queries.
        context = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            context = self.layer_norm(context)
        q_part, k_part, v_part = torch.chunk(self.qkv_net(context), 3, dim=-1)
        q_part = q_part[-qlen:]
        klen = k_part.size(0)

        q_part = q_part.view(qlen, bsz, n_head, d_head)
        k_part = k_part.view(klen, bsz, n_head, d_head)
        v_part = v_part.view(klen, bsz, n_head, d_head)

        # Fit the relative terms to klen rows.
        if klen > r_emb.size(0):
            r_emb = torch.cat([r_emb[0:1].expand(klen - r_emb.size(0), -1, -1), r_emb], 0)
            r_bias = torch.cat([r_bias[0:1].expand(klen - r_bias.size(0), -1), r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        # Each term is [qlen x klen x bsz x n_head]; d_term is [1 x klen x 1 x n_head].
        content_score = torch.einsum('ibnd,jbnd->ijbn', (q_part + r_w_bias[None], k_part))
        b_term = torch.einsum('ibnd,jnd->ijbn', (q_part, r_emb))
        d_term = r_bias[None, :, None]
        attn_score = (content_score + self._rel_shift(b_term + d_term)) * self.scale

        if attn_mask is not None and attn_mask.any().item():
            # NOTE(review): a 2-D mask is broadcast as [1 x d0 x d1 x 1] — confirm intended layout.
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # Normalise over keys (dim 1).
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # [qlen x bsz x n_head x d_head] -> [qlen x bsz x n_head * d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, v_part))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), n_head * d_head)

        attn_out = self.drop(self.o_net(attn_vec))
        residual = w + attn_out
        outputs = [residual if self.pre_lnorm else self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)
        return outputs

thomwolf's avatar
thomwolf committed
707
708
709
710
711
712
713
714
715
716


class DecoderLayer(nn.Module):
    """Decoder block made of a ``MultiHeadAttn`` sub-layer followed by a ``PositionwiseFF`` sub-layer.

    Every extra keyword argument is forwarded unchanged to ``MultiHeadAttn``; ``pre_lnorm``
    (``None`` when absent) is additionally handed to the feed-forward sub-layer.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(DecoderLayer, self).__init__()

        use_pre_lnorm = kwargs.get('pre_lnorm')
        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=use_pre_lnorm)

    def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
        """Apply attention, then the position-wise feed-forward network.

        Returns a list whose first element is the feed-forward output; any further
        elements returned by the attention sub-layer (e.g. attention probabilities
        when they are requested) are appended unchanged.
        """
        attn_outputs = self.dec_attn(dec_inp,
                                     attn_mask=dec_attn_mask,
                                     mems=mems,
                                     head_mask=head_mask)
        hidden, extras = attn_outputs[0], attn_outputs[1:]
        return [self.pos_ff(hidden)] + extras
thomwolf's avatar
thomwolf committed
726
727
728
729
730
731
732
733
734
735
736

class RelLearnableDecoderLayer(nn.Module):
    """Decoder block: ``RelLearnableMultiHeadAttn`` followed by ``PositionwiseFF``.

    Every extra keyword argument is forwarded unchanged to the attention sub-layer;
    ``pre_lnorm`` (``None`` when absent) is additionally handed to the feed-forward sub-layer.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelLearnableDecoderLayer, self).__init__()

        use_pre_lnorm = kwargs.get('pre_lnorm')
        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=use_pre_lnorm)

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
        """Apply relative-position attention, then the position-wise feed-forward network.

        ``r_emb``, ``r_w_bias`` and ``r_bias`` are passed positionally to the attention
        sub-layer. Returns a list whose first element is the feed-forward output, followed
        by any extra outputs of the attention sub-layer.
        """
        attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                                     attn_mask=dec_attn_mask,
                                     mems=mems,
                                     head_mask=head_mask)
        hidden, extras = attn_outputs[0], attn_outputs[1:]
        return [self.pos_ff(hidden)] + extras
thomwolf's avatar
thomwolf committed
747
748
749
750
751
752
753
754
755
756
757

class RelPartialLearnableDecoderLayer(nn.Module):
    """Decoder block: ``RelPartialLearnableMultiHeadAttn`` followed by ``PositionwiseFF``.

    Every extra keyword argument is forwarded unchanged to the attention sub-layer;
    ``pre_lnorm`` (``None`` when absent) is additionally handed to the feed-forward sub-layer.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()

        use_pre_lnorm = kwargs.get('pre_lnorm')
        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=use_pre_lnorm)

    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
        """Apply attention with positional input ``r``, then the position-wise feed-forward network.

        Returns a list whose first element is the feed-forward output, followed by any
        extra outputs of the attention sub-layer.
        """
        attn_outputs = self.dec_attn(dec_inp, r,
                                     attn_mask=dec_attn_mask,
                                     mems=mems,
                                     head_mask=head_mask)
        hidden, extras = attn_outputs[0], attn_outputs[1:]
        return [self.pos_ff(hidden)] + extras
thomwolf's avatar
thomwolf committed
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826



class AdaptiveEmbedding(nn.Module):
    """Token embedding whose vocabulary is split into clusters at ``cutoffs``.

    * ``div_val == 1``: a single table of width ``d_embed``, projected to ``d_proj`` only
      when the two sizes differ.
    * otherwise: cluster ``i`` (token ids in ``[cutoff_ends[i], cutoff_ends[i + 1])``) gets its
      own table of width ``d_embed // div_val ** i`` and its own projection to ``d_proj``.

    The output is always multiplied by ``sqrt(d_proj)``.

    Args:
        n_token: vocabulary size.
        d_embed: embedding width of the first (or only) cluster.
        d_proj: output width.
        cutoffs: sequence (list or tuple) of cluster boundaries; presumably increasing and
            below ``n_token`` -- TODO confirm, it is not validated here.
        div_val: width divisor applied per cluster.
        sample_softmax: when ``> 0`` the single ``div_val == 1`` table uses sparse gradients.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        # list() so tuples (or any sequence) of cutoffs are accepted as well as lists.
        self.cutoffs = list(cutoffs) + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        # Projection matrices are allocated uninitialized (torch.Tensor); presumably the
        # owning model initializes them (see its init_weights) -- NOTE(review): confirm.
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))

    def forward(self, inp):
        """Embed integer token ids ``inp`` of any shape; returns ``inp.shape + (d_proj,)``."""
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            # reshape (not view) so non-contiguous inputs are accepted too.
            inp_flat = inp.reshape(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                   dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                # squeeze(-1) rather than squeeze(): with exactly one matching token the
                # latter yields a 0-d tensor, while index_select/index_copy_ expect a 1-D index.
                indices_i = mask_i.nonzero().squeeze(-1)

                if indices_i.numel() == 0:
                    continue

                # Shift ids so they index into this cluster's own table.
                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed_shape = inp.size() + (self.d_proj,)
            embed = emb_flat.view(embed_shape)

        embed.mul_(self.emb_scale)

        return embed


835
class TransfoXLPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = TransfoXLConfig
    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_transfo_xl
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)

    def _init_weight(self, weight):
        """Initialize ``weight`` according to ``config.init``.

        ``'uniform'`` draws from [-init_range, init_range]; ``'normal'`` uses N(0, init_std).
        Any other value leaves the tensor untouched.
        """
        if self.config.init == 'uniform':
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
        elif self.config.init == 'normal':
            nn.init.normal_(weight, 0.0, self.config.init_std)

    def _init_bias(self, bias):
        """Set ``bias`` to zero."""
        nn.init.constant_(bias, 0.0)

    def init_weights(self, m):
        """ Initialize the weights of sub-module ``m`` (intended for use with ``self.apply``).

        Dispatch is on the class name of ``m``. The order of the checks matters:
        ``AdaptiveEmbedding`` must be matched before the generic ``Embedding`` branch.
        Attributes that are missing, or registered but ``None``, are skipped.
        """
        classname = m.__class__.__name__
        if 'Linear' in classname:
            if getattr(m, 'weight', None) is not None:
                self._init_weight(m.weight)
            if getattr(m, 'bias', None) is not None:
                self._init_bias(m.bias)
        elif 'AdaptiveEmbedding' in classname:
            if hasattr(m, 'emb_projs'):
                for proj in m.emb_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
        elif 'Embedding' in classname:
            if getattr(m, 'weight', None) is not None:
                self._init_weight(m.weight)
        elif 'ProjectedAdaptiveLogSoftmax' in classname:
            if getattr(m, 'cluster_weight', None) is not None:
                self._init_weight(m.cluster_weight)
            if getattr(m, 'cluster_bias', None) is not None:
                self._init_bias(m.cluster_bias)
            if hasattr(m, 'out_projs'):
                for proj in m.out_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
        elif 'LayerNorm' in classname:
            if getattr(m, 'weight', None) is not None:
                # LayerNorm gains are centred on 1.0, not 0.0.
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if getattr(m, 'bias', None) is not None:
                self._init_bias(m.bias)
        else:
            # Relative-position parameters owned by attention modules or the model itself.
            # Same order as before so random-number consumption is unchanged.
            for name in ('r_emb', 'r_w_bias', 'r_r_bias'):
                param = getattr(m, name, None)
                if param is not None:
                    self._init_weight(param)
            if getattr(m, 'r_bias', None) is not None:
                self._init_bias(m.r_bias)

    def set_num_special_tokens(self, num_special_tokens):
        """No-op for Transformer-XL."""
        pass
thomwolf's avatar
thomwolf committed
899

900
901

class TransfoXLModel(TransfoXLPreTrainedModel):
    """Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").

    Transformer XL uses relative positioning (with sinusoidal patterns) and adaptive softmax inputs which means that:

        - you don't need to specify positioning embeddings indices.

        - the tokens in the vocabulary have to be sorted in decreasing frequency.

    Args:
        config: a TransfoXLConfig class instance with the configuration to build a new model

    Example::

        config = TransfoXLConfig()
        model = TransfoXLModel(config)
    """
    def __init__(self, config):
        super(TransfoXLModel, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.n_token = config.n_token

        self.d_embed = config.d_embed
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_head = config.d_head

        self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
                                          div_val=config.div_val)

        self.drop = nn.Dropout(config.dropout)

        self.n_layer = config.n_layer

        self.tgt_len = config.tgt_len
        self.mem_len = config.mem_len
        self.ext_len = config.ext_len
        # Longest key sequence: current segment + extended context + memory.
        self.max_klen = config.tgt_len + config.ext_len + config.mem_len

        self.attn_type = config.attn_type

        if not config.untie_r:
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))

        # Keyword arguments common to every decoder-layer variant. When untie_r is False the
        # model-level r_w_bias / r_r_bias Parameters are passed to every layer.
        common_kwargs = dict(
            dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
            r_w_bias=None if config.untie_r else self.r_w_bias,
            r_r_bias=None if config.untie_r else self.r_r_bias,
            output_attentions=self.output_attentions)
        # The relative-position variants (attn_type 0 and 1) also receive the segment lengths.
        relative_kwargs = dict(common_kwargs, tgt_len=config.tgt_len,
                               ext_len=config.ext_len, mem_len=config.mem_len)

        self.layers = nn.ModuleList()
        if config.attn_type == 0:  # the default attention
            layer_cls, layer_kwargs = RelPartialLearnableDecoderLayer, relative_kwargs
        elif config.attn_type == 1:  # learnable embeddings
            layer_cls, layer_kwargs = RelLearnableDecoderLayer, relative_kwargs
        elif config.attn_type in [2, 3]:  # absolute embeddings
            layer_cls, layer_kwargs = DecoderLayer, common_kwargs
        else:
            layer_cls, layer_kwargs = None, None
        if layer_cls is not None:
            for _ in range(config.n_layer):
                self.layers.append(layer_cls(
                    config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                    **layer_kwargs))

        self.same_length = config.same_length
        self.clamp_len = config.clamp_len

        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 1:  # learnable
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2:  # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3:  # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))

        self.apply(self.init_weights)

    def backward_compatible(self):
        self.sample_softmax = -1

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Update the segment/memory lengths (note: ``max_klen`` is not recomputed)."""
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def _prune_heads(self, heads):
        logger.info("Head pruning is not implemented for Transformer-XL model")

    def init_mems(self, data):
        """Return ``n_layer`` zero memories of shape [mem_len, data.size(1), d_model], or ``None``.

        ``data`` is expected length-first ([len, bsz]) so ``data.size(1)`` is the batch size.
        ``None`` is returned when ``mem_len`` is not positive (memory disabled).
        """
        if self.mem_len > 0:
            param = next(self.parameters())
            return [torch.zeros(self.mem_len, data.size(1), self.config.d_model,
                                dtype=param.dtype, device=param.device)
                    for _ in range(self.n_layer)]
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        """Build the memories for the next segment from the old ``mems`` and the new ``hids``.

        Args:
            hids: per-layer hidden states of the current segment, each [qlen, bsz, d_model].
            mems: per-layer memories, each [mlen, bsz, d_model], or ``None``.
            qlen: length of the current segment.
            mlen: length of the memories that were used for this segment.

        Returns:
            ``None`` when ``mems`` is ``None``; otherwise a list of detached tensors holding at
            most ``mem_len`` steps each.
        """
        # does not deal with None
        if mems is None:
            return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            end_idx = mlen + max(0, qlen - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            new_mems = [torch.cat([mem, hid], dim=0)[beg_idx:end_idx].detach()
                        for mem, hid in zip(mems, hids)]

        return new_mems

    def _forward(self, dec_inp, mems=None, head_mask=None):
        """Core forward pass on length-first input.

        Args:
            dec_inp: token ids of shape [qlen, bsz].
            mems: list (one per layer) of memories [mlen, bsz, d_model], or ``None``.
            head_mask: ``None``, or a tensor of shape [n_head] or [n_layer, n_head] (1.0 keeps a head).

        Returns:
            ``[last_hidden_state, new_mems]`` followed by all hidden states and/or all attentions
            when enabled in the config; tensors are transposed back to batch-first.
        """
        qlen, bsz = dec_inp.size()

        # Prepare head mask if needed.
        # 1.0 in head_mask indicates we keep the head.
        # Input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a mask per layer)
        # and is reshaped to [num_hidden_layers x 1 x 1 x 1 x n_head] so that each layer's slice
        # broadcasts against attention probabilities laid out as [qlen x klen x bsz x n_head].
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            # switch to float if needed + fp16 compatibility
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
        else:
            head_mask = [None] * self.n_layer

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        # Mask of shape [qlen x klen x 1]; 1 marks keys a query must not attend to
        # (keys after the query's own position, offset by the memory length).
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            # The tril term additionally masks the oldest keys (intended so that every
            # query sees a context of the same length).
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
                             + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None]
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1 + mlen).byte()[:, :, None]

        hids = []
        attentions = []
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        elif self.attn_type == 1:  # learnable
            core_out = self.drop(word_emb)
            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len:]
                    r_bias = self.r_bias[i][-self.clamp_len:]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
                                      r_bias, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        elif self.attn_type == 2:  # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    # Out of place: an in-place add would mutate the caller's memory tensor,
                    # which _update_mems then re-caches, accumulating the positional embedding.
                    mems_i = mems_i + pos_emb[:mlen]
                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        # Not enough learned positions: left-pad by repeating the first one.
                        cur_emb_pad = cur_emb[0:1].expand(mlen - cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    # Out of place so the caller's memory tensor is not modified.
                    mems_i = mems_i + cur_emb.view(mlen, 1, -1)
                # NOTE: kept in place as before: hids[i] is this same tensor, so the stored
                # hidden state (layer input) includes this embedding.
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])

        core_out = self.drop(core_out)

        # Arguments in the order of _update_mems' signature (hids, mems, qlen, mlen).
        new_mems = self._update_mems(hids, mems, qlen, mlen)

        # We transpose back here to shape [bsz, len, hidden_dim]
        outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
        if self.output_hidden_states:
            # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
            hids.append(core_out)
            hids = [t.transpose(0, 1).contiguous() for t in hids]
            outputs.append(hids)
        if self.output_attentions:
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
            attentions = [t.permute(2, 3, 0, 1).contiguous() for t in attentions]
            outputs.append(attentions)
        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)

    def forward(self, input_ids, mems=None, head_mask=None):
        """
        Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**

        Args:
            `input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
                with the token indices selected in the range [0, self.config.n_token[
            `mems`: optional memory of hidden states from previous forward passes
                as a list (num layers) of hidden states at the entry of each layer
                each hidden states has shape [self.config.mem_len, bsz, self.config.d_model]
                Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
            `head_mask`: optional tensor of shape [num_heads] or [num_layers, num_heads];
                1.0 keeps a head.

        Returns:
            A list ``[last_hidden_state, new_mems, (all_hidden_states), (all_attentions)]``; the last two
            entries are only present when ``config.output_hidden_states`` / ``config.output_attentions`` are set.

                ``last_hidden_state``: the encoded-hidden-states at the top of the model
                as a ``torch.FloatTensor`` of size [batch_size, sequence_length, self.config.d_model]

                ``new_mems``: list (num layers) of updated mem states at the entry of each layer
                each mem state is a ``torch.FloatTensor`` of size [at most self.config.mem_len, batch_size, self.config.d_model]
                Note that the first two dimensions are transposed in ``mems`` with regards to ``input_ids`` and
                ``labels``. ``None`` when memory is disabled (``mem_len`` not positive).

                ``all_hidden_states``: list of tensors of size [batch_size, sequence_length, self.config.d_model]:
                the input of each layer followed by the final output.

                ``all_attentions``: list (num layers) of tensors of size
                [batch_size, num_heads, query_seq_len, key_seq_len].

        Example::

            # Already been converted into token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_ids_next = torch.LongTensor([[53, 21, 1], [64, 23, 100]])

            outputs = model(input_ids)
            last_hidden_state, new_mems = outputs[:2]

            # Another time on input_ids_next using the memory:
            outputs = model(input_ids_next, new_mems)
            last_hidden_state, new_mems = outputs[:2]
        """
        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        input_ids = input_ids.transpose(0, 1).contiguous()

        if mems is None:
            mems = self.init_mems(input_ids)
        outputs = self._forward(input_ids, mems=mems, head_mask=head_mask)

        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
thomwolf's avatar
thomwolf committed
1225
1226
1227
1228
1229


class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    """Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context").

    This model adds an (adaptive) softmax head on top of the ``TransfoXLModel``.

    Transformer XL uses relative positioning (with sinusoidal patterns) and adaptive softmax inputs, which means that:

        - you don't need to specify positioning embeddings indices
        - the tokens in the vocabulary have to be sorted in decreasing frequency.

    Call ``self.tie_weights()`` if you update/load the weights of the transformer to keep the weights tied.

    Args:
        config: a ``TransfoXLConfig`` class instance with the configuration to build a new model

    Example::

        config = TransfoXLConfig()
        model = TransfoXLLMHeadModel(config)
    """
    def __init__(self, config):
        super(TransfoXLLMHeadModel, self).__init__(config)
        self.transformer = TransfoXLModel(config)
        self.sample_softmax = config.sample_softmax
        if config.sample_softmax > 0:
            # Sampled softmax. ``LogUniformSampler`` is not among the names imported at the top of this
            # module (only ``ProjectedAdaptiveLogSoftmax`` and ``sample_logits`` are), so bring it into
            # scope here; otherwise this branch fails with a NameError.
            from .modeling_transfo_xl_utilities import LogUniformSampler
            self.out_layer = nn.Linear(config.d_model, config.n_token)
            self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
        else:
            # Adaptive softmax (including standard softmax).
            self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                    config.cutoffs, div_val=config.div_val)
        self.apply(self.init_weights)
        self.tie_weights()

    def tie_weights(self):
        """
        Run this to be sure output and input (adaptive) softmax weights are tied.

        ``config.tie_weight`` shares the output layer weights with the word-embedding weights.
        ``config.tie_projs`` (one flag per adaptive-softmax output projection) shares the output
        projections with the embedding projections.
        """
        if self.sample_softmax > 0:
            # Sampled softmax: a single output layer.
            if self.config.tie_weight:
                self.out_layer.weight = self.transformer.word_emb.weight
        else:
            # Adaptive softmax (including standard softmax): one output layer per cluster.
            if self.config.tie_weight:
                for i, out_layer in enumerate(self.crit.out_layers):
                    out_layer.weight = self.transformer.word_emb.emb_layers[i].weight
            if self.config.tie_projs:
                for i, tie_proj in enumerate(self.config.tie_projs):
                    if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
                        # With div_val == 1 every output projection is tied to embedding projection 0
                        # (presumably the only one in that case — TODO confirm in AdaptiveEmbedding).
                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
                    elif tie_proj and self.config.div_val != 1:
                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Forward ``tgt_len``, ``ext_len`` and ``mem_len`` to ``self.transformer.reset_length``."""
        self.transformer.reset_length(tgt_len, ext_len, mem_len)

    def init_mems(self, data):
        """Return ``self.transformer.init_mems(data)``, the initial memory for ``data``."""
        return self.transformer.init_mems(data)

    def forward(self, input_ids, labels=None, mems=None, head_mask=None):
        """
        Performs a model forward pass. **Can be called by calling the class directly, once it has been instantiated.**

        Args:
            `input_ids`: a ``torch.LongTensor`` of shape [batch_size, sequence_length]
                with the token indices selected in the range [0, self.config.n_token[
            `labels`: an optional ``torch.LongTensor`` of shape [batch_size, sequence_length]
                with the labels token indices selected in the range [0, self.config.n_token[
            `mems`: an optional memory of hidden states from previous forward passes
                as a list (num layers) of hidden states at the entry of each layer
                each hidden states has shape [self.config.mem_len, bsz, self.config.d_model]
                Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
            `head_mask`: an optional head mask, passed unchanged to ``self.transformer``

        Returns:
            A list ``[softmax_output, (None), new_mems, ...]``:

                ``softmax_output``: output of the (adaptive) softmax. If ``labels`` is ``None``, it is the log
                probabilities of the tokens, of shape [batch_size, sequence_length, n_tokens]. Otherwise, it is
                the negative log likelihood of ``labels``, of shape [batch_size, sequence_length], and it is
                followed by a ``None`` placeholder (the logits are not computed, to speed up the adaptive softmax).

                ``new_mems``: list (num layers) of updated mem states at the entry of each layer
                each mem state is a ``torch.FloatTensor`` of size [self.config.mem_len, batch_size, self.config.d_model]
                Note that the first two dimensions are transposed in ``mems`` with regards to ``input_ids`` and
                ``labels``

                Any further entries are the extra outputs returned by ``self.transformer`` after ``new_mems``.

        Raises:
            NotImplementedError: when training with a sampled softmax (``config.sample_softmax > 0``) and
                ``labels`` is given.

        Example::

            # Already been converted into BPE token ids
            input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
            input_ids_next = torch.LongTensor([[53, 21, 1], [64, 23, 100]])

            log_probs, new_mems = model(input_ids)[:2]
            # or
            log_probs, new_mems = model.forward(input_ids)[:2]

            # Another time on input_ids_next using the memory:
            log_probs, new_mems = model(input_ids_next, mems=new_mems)[:2]
        """
        bsz = input_ids.size(0)
        tgt_len = input_ids.size(1)

        transformer_outputs = self.transformer(input_ids, mems, head_mask)

        last_hidden = transformer_outputs[0]
        pred_hid = last_hidden[:, -tgt_len:]
        # list(...) so the concatenations below also work if the transformer returns a tuple.
        outputs = list(transformer_outputs[1:])

        if self.sample_softmax > 0 and self.training:
            assert self.config.tie_weight
            if labels is not None:
                # TODO: the sampled-softmax loss is not implemented; fail before computing sampled logits
                # that would be thrown away.
                raise NotImplementedError
            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
            softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
            outputs = [softmax_output] + outputs
        else:
            # reshape (not view) so a non-contiguous hidden state does not raise.
            softmax_output = self.crit(pred_hid.reshape(-1, pred_hid.size(-1)), labels)
            if labels is None:
                # Log probabilities: [bsz, tgt_len, n_tokens]
                softmax_output = softmax_output.view(bsz, tgt_len, -1)
                outputs = [softmax_output] + outputs
            else:
                # Per-token negative log likelihood: [bsz, tgt_len]; None stands in for the skipped logits.
                softmax_output = softmax_output.view(bsz, tgt_len)
                outputs = [softmax_output, None] + outputs

        return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)