mem_transformer.py 32.9 KB
Newer Older
Zhilin Yang's avatar
init  
Zhilin Yang committed
1
2
3
4
5
6
7
8
9
10
11
import sys
import math
import functools

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append('utils')
12
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax, Projection
Zhilin Yang's avatar
init  
Zhilin Yang committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from log_uniform_sampler import LogUniformSampler, sample_logits

class PositionalEmbedding(nn.Module):
    """Sinusoidal position embedding.

    ``inv_freq[k] = 1 / 10000 ** (2k / demb)`` for every even channel index
    ``2k < demb``. For a vector of positions the output concatenates
    ``sin(pos * inv_freq)`` and ``cos(pos * inv_freq)`` along the last axis.
    """

    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()

        self.demb = demb

        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        """Embed a 1-D tensor of positions.

        Args:
            pos_seq: 1-D tensor of positions, length L.
            bsz: if given, the result is expanded (a view, no copy) along a
                batch axis of this size.

        Returns:
            ``[L, bsz, 2 * len(inv_freq)]`` when ``bsz`` is given, otherwise
            ``[L, 1, 2 * len(inv_freq)]``.
        """
        # Outer product pos_seq[i] * inv_freq[k] via broadcasting. torch.ger
        # is deprecated; this form is equivalent and works on all versions.
        sinusoid_inp = pos_seq[:, None] * self.inv_freq[None, :]
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        return pos_emb[:, None, :]

33

Zhilin Yang's avatar
init  
Zhilin Yang committed
34
class PositionwiseFF(nn.Module):
    """Position-wise feed-forward sub-layer with a residual connection.

    ``CoreNet`` is Linear(d_model -> d_inner) -> ReLU -> Dropout ->
    Linear(d_inner -> d_model) -> Dropout. With ``pre_lnorm`` the input is
    layer-normalized before CoreNet and the raw input is added back
    afterwards; otherwise LayerNorm is applied to ``inp + CoreNet(inp)``.
    """

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)

        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        """Apply the block; ``inp``'s last dim must be ``d_model``.

        Returns a tensor of the same shape as ``inp``.
        """
        if self.pre_lnorm:
            ##### layer normalization + positionwise feed-forward
            core_out = self.CoreNet(self.layer_norm(inp))

            ##### residual connection
            output = core_out + inp
        else:
            ##### positionwise feed-forward
            core_out = self.CoreNet(inp)

            ##### residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output
Jiezhong Qiu's avatar
Jiezhong Qiu committed
68

Zhilin Yang's avatar
init  
Zhilin Yang committed
69
class MultiHeadAttn(nn.Module):
Jiezhong Qiu's avatar
Jiezhong Qiu committed
70
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
Zhilin Yang's avatar
init  
Zhilin Yang committed
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
                 pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

    def forward(self, h, attn_mask=None, mems=None):
        ##### multihead attention
        # [hlen x bsz x n_head x d_head]

        if mems is not None:
            c = torch.cat([mems, h], 0)
        else:
            c = h

        if self.pre_lnorm:
            ##### layer normalization
            c = self.layer_norm(c)

        head_q = self.q_net(h)
        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)

        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
Jiezhong Qiu's avatar
Jiezhong Qiu committed
117
                attn_score.masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
Zhilin Yang's avatar
init  
Zhilin Yang committed
118
            elif attn_mask.dim() == 3:
Jiezhong Qiu's avatar
Jiezhong Qiu committed
119
                attn_score.masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))
Zhilin Yang's avatar
init  
Zhilin Yang committed
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = h + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(h + attn_out)

        return output

class RelMultiHeadAttn(nn.Module):
    """Base class for the relative-position multi-head attention variants.

    Holds the shared fused ``qkv_net`` projection, output projection,
    dropout, LayerNorm and the position-shifting helpers. Subclasses
    implement ``forward``.
    """

    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 moe=False, moe_num_expert=64, moe_top_k=2):
        # tgt_len / ext_len / mem_len and the moe* arguments are not used
        # here; they are accepted so callers can forward one shared keyword
        # set (the decoder layers pass **kwargs straight through).
        super(RelMultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

    def _parallelogram_mask(self, h, w, left=False):
        """Return an ``(h, w)`` uint8 band mask.

        Starts from all ones, makes the top-left ``m x m`` square upper
        triangular and the bottom-right ``m x m`` square lower triangular
        (``m = min(h, w)``). With ``left=False`` the rows are flipped.
        """
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m,:m] = torch.triu(mask[:m,:m])
        mask[-m:,-m:] = torch.tril(mask[-m:,-m:])

        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        """Pad ``x`` along dim 1 with ``qlen - 1`` zeros and gather a band.

        The padded tensor is expanded to ``qlen`` rows, and the entries
        selected by ``mask`` (as produced by ``_parallelogram_mask``) are
        reshaped to ``[qlen, klen, x.size(2), x.size(3)]``.
        """
        # A zero-width pad when qlen <= 1 leaves x unchanged, without relying
        # on the legacy behavior of concatenating a 1-D empty tensor.
        pad_len = max(qlen - 1, 0)
        zero_pad = torch.zeros((x.size(0), pad_len, x.size(2), x.size(3)),
                               device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        # masked_select expects a bool mask; _parallelogram_mask builds uint8.
        x = x_padded.masked_select(mask[:,:,None,None].bool()) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x

    def _rel_shift(self, x, zero_triu=False):
        """Re-align relative-position scores ``x`` (``[qlen, klen, ...]``).

        Prepends one zero column on dim 1, reinterprets the buffer as
        ``[klen + 1, qlen, ...]``, drops the first row and views it back as
        ``x``'s shape. With ``zero_triu`` entries above diagonal
        ``klen - qlen`` are zeroed.
        """
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        x = x_padded[1:].view_as(x)

        if zero_triu:
            # Build the mask on x's device/dtype; a default CPU tensor would
            # fail to multiply with CUDA inputs.
            ones = torch.ones((x.size(0), x.size(1)),
                              device=x.device, dtype=x.dtype)
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]

        return x

    def forward(self, w, r, attn_mask=None, mems=None):
        raise NotImplementedError

class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    """Relative attention with a learned projection of the position input ``r``.

    The score for query i and key j is ``(AC + BD) * scale``:
    AC = (q_i + r_w_bias) . k_j is the content term, and
    BD = (q_i + r_r_bias) . r_net(r)_j is the position term, re-aligned by
    ``_rel_shift``.
    """

    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

        # Projects the relative position embeddings into per-head key space.
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
        """Attend from the current segment ``w`` over ``[mems; w]``.

        Args:
            w: current segment, ``[qlen, bsz, d_model]``.
            r: relative position input with leading dim ``rlen`` and last dim
                ``d_model``; ``r_net(r)`` must reshape to
                ``[rlen, n_head, d_head]``. ``AC + BD`` requires
                ``rlen == klen`` (assumed; TODO confirm against the caller).
            r_w_bias, r_r_bias: biases broadcast against
                ``[qlen, bsz, n_head, d_head]`` (presumably ``[n_head, d_head]``).
            attn_mask: optional 2-D ``[qlen, klen]`` or 3-D
                ``[qlen, klen, bsz]`` mask; nonzero entries are excluded.
            mems: optional cached states ``[mlen, bsz, d_model]``.

        Returns:
            Tensor with the same shape as ``w``.
        """
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        # Keys/values span [mems; w]; queries are the last qlen positions.
        cat = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            w_heads = self.qkv_net(self.layer_norm(cat))
        else:
            w_heads = self.qkv_net(cat)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        w_head_q = w_head_q[-qlen:]

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # rlen x n_head x d_head

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias                                         # qlen x bsz x n_head x d_head
        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + r_r_bias
        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))              # qlen x rlen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            # Masking is done in float32 and cast back to the score dtype.
            if attn_mask.dim() == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None,:,:,None].bool(), -float('inf')).type_as(attn_score)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:,:,:,None].bool(), -float('inf')).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output

class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    """Relative attention with directly learned per-head position tensors.

    The score is ``(AC + shift(B + D)) * scale``: AC is the content term with
    ``r_w_bias`` added to the queries, B scores queries against ``r_emb``,
    and D is the per-position bias ``r_bias``.
    """

    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
        """Attend from the current segment ``w`` over ``[mems; w]``.

        Args:
            w: current segment, ``[qlen, bsz, d_model]``.
            r_emb: ``[rlen, n_head, d_head]``, used for term B.
            r_w_bias: ``[n_head, d_head]``, used for term C.
            r_bias: ``[rlen, n_head]``, used for term D.
                ``r_emb`` and ``r_bias`` are brought to exactly ``klen`` rows.
                When too short they are left-padded by repeating their first
                row; otherwise their last ``klen`` rows are kept.
            attn_mask: optional 2-D ``[qlen, klen]`` or 3-D
                ``[qlen, klen, bsz]`` mask; nonzero entries are excluded.
            mems: optional cached states ``[mlen, bsz, d_model]``.

        Returns:
            Tensor with the same shape as ``w``.
        """
        qlen, bsz = w.size(0), w.size(1)

        # Keys/values span [mems; w]; queries are the last qlen positions.
        cat = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            w_heads = self.qkv_net(self.layer_norm(cat))
        else:
            w_heads = self.qkv_net(cat)
        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        w_head_q = w_head_q[-qlen:]

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

        if klen > r_emb.size(0):
            # Too few relative positions: repeat the first row on the left.
            r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
            r_emb = torch.cat([r_emb_pad, r_emb], 0)
            r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
            r_bias = torch.cat([r_bias_pad, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[None]                                   # qlen x bsz x n_head x d_head

        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head
        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))                  # qlen x klen x bsz x n_head
        D_ = r_bias[None, :, None]                                              # 1    x klen x 1   x n_head
        BD = self._rel_shift(B_ + D_)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None,:,:,None].bool(), -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:,:,:,None].bool(), -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output

Rick Ho's avatar
Rick Ho committed
379
380
from fmoe import FMoETransformerMLP
class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
    """Mixture-of-experts replacement for a position-wise feed-forward block.

    The expert MLP is fmoe's ``FMoETransformerMLP``, configured with a
    ReLU + Dropout ``activation``. This wrapper adds output dropout, the
    residual connection and LayerNorm in the same pre-/post-norm arrangement
    as ``PositionwiseFF``.
    """

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, moe_num_expert=64, moe_top_k=2):
        # Hidden-layer nonlinearity handed to the expert MLP; ReLU followed by
        # dropout, like the hidden layer of PositionwiseFF.CoreNet.
        activation = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        super().__init__(num_expert=moe_num_expert, d_model=d_model, d_hidden=d_inner, top_k=moe_top_k,
                activation=activation)

        self.pre_lnorm = pre_lnorm
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp):
        """Run the MoE MLP with dropout, residual connection and LayerNorm."""
        if self.pre_lnorm:
            ##### layer normalization + positionwise feed-forward
            core_out = super().forward(self.layer_norm(inp))
            core_out = self.dropout(core_out)

            ##### residual connection
            output = core_out + inp
        else:
            ##### positionwise feed-forward
            core_out = super().forward(inp)
            core_out = self.dropout(core_out)

            ##### residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output
Rick Ho's avatar
Rick Ho committed
410

Zhilin Yang's avatar
init  
Zhilin Yang committed
411
412
413
414
415
class DecoderLayer(nn.Module):
    """Decoder layer with absolute-position attention.

    ``MultiHeadAttn`` over the current segment (plus optional memory) is
    followed by a position-wise feed-forward block. The block is the dense
    ``PositionwiseFF``, or ``CustomizedMoEPositionwiseFF`` when ``moe`` is
    truthy.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(DecoderLayer, self).__init__()

        # MultiHeadAttn is given only the keywords it is known to take. The
        # MoE / length options in kwargs are meant for the feed-forward block
        # and other layer types, and would otherwise raise TypeError here.
        attn_kwargs = {k: kwargs[k] for k in ('dropatt', 'pre_lnorm') if k in kwargs}
        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **attn_kwargs)

        pre_lnorm = kwargs.get('pre_lnorm', False)
        # A missing/falsy `moe` selects the dense FF. The former `is False`
        # test sent a missing flag to MoE with None expert counts.
        if not kwargs.get('moe', False):
            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                         pre_lnorm=pre_lnorm)
        else:
            # Forward only the MoE sizes that were given so the FF's own
            # defaults apply otherwise.
            moe_kwargs = {k: kwargs[k] for k in ('moe_num_expert', 'moe_top_k') if k in kwargs}
            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                                      pre_lnorm=pre_lnorm,
                                                      **moe_kwargs)

    def forward(self, dec_inp, dec_attn_mask=None, mems=None):
        """Apply attention (with optional ``mems``) then the feed-forward block."""
        output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output
Zhilin Yang's avatar
init  
Zhilin Yang committed
432
433
434
435
436
437
438
439

class RelLearnableDecoderLayer(nn.Module):
    """Decoder layer built on ``RelLearnableMultiHeadAttn``.

    Attention with learned relative-position tensors is followed by a dense
    ``PositionwiseFF``, or ``CustomizedMoEPositionwiseFF`` when ``moe`` is
    truthy.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelLearnableDecoderLayer, self).__init__()

        # The relative attention constructor accepts the full shared keyword
        # set (lengths, pre_lnorm, moe*), so kwargs is forwarded as is.
        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                  **kwargs)

        pre_lnorm = kwargs.get('pre_lnorm', False)
        # A missing/falsy `moe` selects the dense FF (the former `is False`
        # test sent a missing flag to MoE with None expert counts).
        if not kwargs.get('moe', False):
            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                         pre_lnorm=pre_lnorm)
        else:
            # Forward only the MoE sizes that were given.
            moe_kwargs = {k: kwargs[k] for k in ('moe_num_expert', 'moe_top_k') if k in kwargs}
            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                                      pre_lnorm=pre_lnorm,
                                                      **moe_kwargs)

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
        """Apply relative attention then the feed-forward block.

        ``r_emb``, ``r_w_bias`` and ``r_bias`` are passed unchanged to
        ``RelLearnableMultiHeadAttn.forward``.
        """
        output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                               attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output
Zhilin Yang's avatar
init  
Zhilin Yang committed
458
459
460
461
462
463
464
465

class RelPartialLearnableDecoderLayer(nn.Module):
    """Decoder layer built on ``RelPartialLearnableMultiHeadAttn``.

    Relative attention (projected positions plus two global biases) is
    followed by a dense ``PositionwiseFF``, or
    ``CustomizedMoEPositionwiseFF`` when ``moe`` is truthy.
    """

    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()

        # The relative attention constructor accepts the full shared keyword
        # set (lengths, pre_lnorm, moe*), so kwargs is forwarded as is.
        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                         d_head, dropout, **kwargs)

        pre_lnorm = kwargs.get('pre_lnorm', False)
        # A missing/falsy `moe` selects the dense FF (the former `is False`
        # test sent a missing flag to MoE with None expert counts).
        if not kwargs.get('moe', False):
            self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                         pre_lnorm=pre_lnorm)
        else:
            # Forward only the MoE sizes that were given.
            moe_kwargs = {k: kwargs[k] for k in ('moe_num_expert', 'moe_top_k') if k in kwargs}
            self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                                      pre_lnorm=pre_lnorm,
                                                      **moe_kwargs)

    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
        """Apply relative attention then the feed-forward block.

        ``r``, ``r_w_bias`` and ``r_r_bias`` are passed unchanged to
        ``RelPartialLearnableMultiHeadAttn.forward``.
        """
        output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
                               attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output
Zhilin Yang's avatar
init  
Zhilin Yang committed
484

485

Zhilin Yang's avatar
init  
Zhilin Yang committed
486
class AdaptiveEmbedding(nn.Module):
    """Adaptive input embedding with per-cluster table widths.

    With ``div_val == 1`` a single ``[n_token, d_embed]`` table is used,
    followed by a projection when ``d_proj != d_embed``. Otherwise the
    vocabulary is split at ``cutoffs``, and cluster ``i`` uses a table of
    width ``d_embed // div_val**i`` plus a projection. All outputs are
    multiplied by ``sqrt(d_proj)``.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        # Builds a new list, so the caller's `cutoffs` is not modified.
        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        # Cluster i covers token ids [cutoff_ends[i], cutoff_ends[i + 1]).
        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ModuleList()

        if div_val == 1:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(Projection(d_proj, d_embed))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
                self.emb_projs.append(Projection(d_proj, d_emb_i))

    def forward(self, inp):
        """Embed integer token ids ``inp``.

        Returns a tensor of shape ``[*inp.shape, d_proj]``. When a projection
        is applied, this assumes ``Projection.weight`` has ``d_proj`` rows
        (F.linear computes ``x @ W.T``); TODO confirm against
        proj_adaptive_softmax.
        """
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0].weight)
        else:
            param = next(self.parameters())
            # reshape (not view) so non-contiguous inputs are accepted too.
            inp_flat = inp.reshape(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                # squeeze(1) keeps a 1-D index even when exactly one token
                # falls in this cluster (a bare squeeze() would give 0-d).
                indices_i = mask_i.nonzero().squeeze(1)

                if indices_i.numel() == 0:
                    continue

                # Shift ids into this cluster's local table range.
                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i].weight)

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed = emb_flat.view(*inp.size(), self.d_proj)

        embed.mul_(self.emb_scale)

        return embed

class MemTransformerLM(nn.Module):
    """Transformer-XL language model with segment-level recurrent memory.

    Tokens are embedded with ``AdaptiveEmbedding`` and run through ``n_layer``
    decoder layers, whose flavour is selected by ``attn_type``:

        0 -- ``RelPartialLearnableDecoderLayer`` with sinusoidal relative
             positions and shared ``r_w_bias`` / ``r_r_bias`` (default);
        1 -- ``RelLearnableDecoderLayer`` with per-layer learnable tables;
        2 -- ``DecoderLayer`` with sinusoidal absolute positions added to the
             input;
        3 -- ``DecoderLayer`` with learnable absolute positions added before
             every layer.

    Targets are scored with a sampled softmax when ``sample_softmax > 0``
    (training only), otherwise with ``ProjectedAdaptiveLogSoftmax``.

    ``moe``, ``moe_num_expert`` and ``moe_top_k`` are forwarded unchanged to
    every decoder layer. ``tgt_len``, ``ext_len`` and ``mem_len`` must be
    ints: their sum is computed eagerly in ``__init__``.
    """

    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None,
                 div_val=1, tie_projs=None, pre_lnorm=False,
                 tgt_len=None, ext_len=None, mem_len=None,
                 cutoffs=None, adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1, moe=False, moe_num_expert=64, moe_top_k=2):
        super(MemTransformerLM, self).__init__()
        # None sentinels replace shared mutable list defaults. The effective
        # defaults are unchanged: tie_projs=[False], cutoffs=[].
        if tie_projs is None:
            tie_projs = [False]
        if cutoffs is None:
            cutoffs = []

        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        # Longest key length a segment can see. It sizes the learnable
        # position tables created for attn_type 1 and 3 in _create_params.
        self.max_klen = tgt_len + ext_len + mem_len

        self.attn_type = attn_type

        # Keyword arguments shared by every decoder-layer flavour.
        layer_kwargs = dict(dropatt=dropatt, pre_lnorm=pre_lnorm,
                            moe=moe, moe_num_expert=moe_num_expert,
                            moe_top_k=moe_top_k)

        self.layers = nn.ModuleList()
        if attn_type == 0: # the default attention
            for _ in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        **layer_kwargs)
                )
        elif attn_type == 1: # learnable embeddings
            for _ in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        **layer_kwargs)
                )
        elif attn_type in [2, 3]: # absolute embeddings
            for _ in range(n_layer):
                self.layers.append(
                    DecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        **layer_kwargs)
                )

        self.sample_softmax = sample_softmax
        self.tie_weight = tie_weight
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                # NOTE(review): the visible part of AdaptiveEmbedding keeps its
                # tables in ``emb_layers``. Confirm it also exposes ``.weight``;
                # otherwise this line raises AttributeError.
                self.out_layer.weight = self.word_emb.weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                    cutoffs, div_val=div_val)

            if tie_weight:
                # Share each output cluster's weight with the matching input
                # embedding table.
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        # Single input table: every output projection shares
                        # the one input projection.
                        self.crit.out_projs[i].weight = self.word_emb.emb_projs[0].weight
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i].weight = self.word_emb.emb_projs[i].weight

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()

    def backward_compatible(self):
        """Disable sampled softmax so ``forward`` always uses ``self.crit``."""
        self.sample_softmax = -1

    def _create_params(self):
        """Create the positional modules/parameters required by ``attn_type``.

        The ``nn.Parameter``s are allocated with ``torch.Tensor(...)`` and are
        therefore uninitialised here; they must be initialised by the caller.
        """
        if self.attn_type == 0: # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        elif self.attn_type == 1: # learnable
            self.r_emb = nn.Parameter(torch.Tensor(
                    self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(
                    self.n_layer, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.Tensor(
                    self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2: # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3: # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                    self.n_layer, self.max_klen, self.n_head, self.d_head))

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Change segment / extended-context / memory lengths.

        Only these three attributes are updated; ``max_klen`` and the
        already-allocated positional tables keep their original size.
        """
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self, x):
        """Return ``n_layer + 1`` empty memory tensors, or None if mem_len <= 0.

        The empty tensors take ``x``'s dtype and device.
        NOTE(review): ``forward`` passes the token-id tensor here, so the empty
        memories carry the id dtype and rely on ``torch.cat`` skipping/promoting
        a 1-D empty tensor. Confirm this is intended.
        """
        if self.mem_len > 0:
            mems = []
            for _ in range(self.n_layer+1):
                empty = torch.empty(0, dtype=x.dtype, device=x.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        """Build the memory for the next segment from this segment's states.

        Args:
            hids: per-layer hidden states of the current segment (qlen rows).
            mems: per-layer memories used for this segment (mlen rows), or None.
            qlen: length of the current segment.
            mlen: length of the current memory.

        Returns:
            A list of detached tensors with at most ``mem_len`` rows each, or
            None when ``mems`` is None.
        """
        # does not deal with None
        if mems is None: return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):

                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
        """Run embedding and decoder layers over one segment.

        Args:
            dec_inp: token ids of shape (qlen, bsz).
            mems: per-layer memories from the previous segment, or None.

        Returns:
            ``(core_out, new_mems)``: the final hidden states after dropout,
            and the memories for the next segment (None if ``mems`` is None).
        """
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        # Causal mask over [memory | current] keys; True marks positions that
        # must NOT be attended. Bool dtype, because masked_fill rejects uint8
        # masks in current PyTorch.
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1+mlen)
                    + torch.tril(all_ones, -mask_shift_len)).bool()[:, :, None] # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1+mlen).bool()[:,:,None]

        hids = []
        if self.attn_type == 0: # default
            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, pos_emb, self.r_w_bias,
                        self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1: # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len :]
                    r_bias = self.r_bias[i][-self.clamp_len :]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, r_emb, self.r_w_bias[i],
                        r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2: # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                # Out-of-place: an in-place add would mutate the caller's
                # memory tensor. mlen > 0 guards the empty initial memory,
                # which cannot broadcast against pos_emb.
                if mems_i is not None and mlen > 0 and i == 0:
                    mems_i = mems_i + pos_emb[:mlen]
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    # Position rows for the memory part. When the table is too
                    # short, left-pad by repeating its first row.
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    # Out-of-place, so the caller's memory is not modified.
                    mems_i = mems_i + cur_emb.view(mlen, 1, -1)
                # Out-of-place, so hids[i] (already appended and later cached)
                # keeps the raw layer input.
                core_out = core_out + self.r_emb[i][-qlen:].view(qlen, 1, -1)

                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        # Argument order matches the signature (hids, mems, qlen, mlen).
        new_mems = self._update_mems(hids, mems, qlen, mlen)

        return core_out, new_mems

    def forward(self, data, target, *mems):
        """Compute the token loss for ``target`` and return updated memory.

        Args:
            data: token ids of shape (qlen, bsz).
            target: target ids. Its first dimension gives how many trailing
                output positions are scored.
            *mems: memories returned by a previous call. When none are given,
                they are created with ``init_mems``.

        Returns:
            ``[loss]`` if memory is disabled, else ``[loss] + new_mems``.
        """
        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
        # So, have to initialize size(0) mems inside the model forward.
        # Moreover, have to return new_mems to allow nn.DataParallel to piece
        # them together.
        if not mems: mems = self.init_mems(data)

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]
        if self.sample_softmax > 0 and self.training:
            assert self.tie_weight
            logit = sample_logits(self.word_emb,
                self.out_layer.bias, target, pred_hid, self.sampler)
            loss = -F.log_softmax(logit, -1)[:, :, 0]
        else:
            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.contiguous().view(-1))
            loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
        else:
            return [loss] + new_mems

if __name__ == '__main__':
    # Smoke test: build small models for several div_val / d_embed settings
    # and run them over random token data, threading memory between batches.
    import argparse

    parser = argparse.ArgumentParser(description='unit test')

    parser.add_argument('--n_layer', type=int, default=4, help='')
    parser.add_argument('--n_rel_layer', type=int, default=4, help='')
    parser.add_argument('--n_head', type=int, default=2, help='')
    parser.add_argument('--d_head', type=int, default=2, help='')
    parser.add_argument('--d_model', type=int, default=200, help='')
    parser.add_argument('--d_embed', type=int, default=200, help='')
    parser.add_argument('--d_inner', type=int, default=200, help='')
    parser.add_argument('--dropout', type=float, default=0.0, help='')
    parser.add_argument('--cuda', action='store_true', help='')
    parser.add_argument('--seed', type=int, default=1111, help='')
    parser.add_argument('--multi_gpu', action='store_true', help='')

    args = parser.parse_args()

    # --seed was previously parsed but never applied. Seed torch's RNG so the
    # random token data and module initialisation below are reproducible.
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if args.cuda else "cpu")

    B = 4
    tgt_len, mem_len, ext_len = 36, 36, 0
    data_len = tgt_len * 20
    args.n_token = 10000

    import data_utils

    data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device)
    diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)

    cutoffs = [args.n_token // 2]
    tie_projs = [False] + [True] * len(cutoffs)

    for div_val in [1, 2]:
        for d_embed in [200, 100]:
            model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
                            args.d_model, args.d_head, args.d_inner, args.dropout,
                            dropatt=args.dropout, tie_weight=True,
                            d_embed=d_embed, div_val=div_val,
                            tie_projs=tie_projs, pre_lnorm=True,
                            tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                            cutoffs=cutoffs, attn_type=0).to(device)

            # Total parameter count of this configuration.
            print(sum(p.numel() for p in model.parameters()))

            # Memory is threaded through consecutive batches: out[0] is the
            # loss, out[1:] are the new memories.
            mems = tuple()
            for idx, (inp, tgt, seqlen) in enumerate(diter):
                print('batch {}'.format(idx))
                out = model(inp, tgt, *mems)
                mems = out[1:]