Commit 958c714e authored by Jiezhong Qiu

add mem2mem attention

parent 44781ed2
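
The diff below adds ExtendedMultiHeadAttn, a variant of MultiHeadAttn in which the cached memory positions act as queries as well as keys ("mem2mem" attention): current positions attend over the memory plus the current segment as before, while memory queries (when an attention mask is supplied) are restricted to memory keys. A rough sketch of the resulting key-visibility layout, with hypothetical sizes and a True-means-blocked convention matching the masked_fill_ calls in the code (illustration only, not part of the commit):

import torch

mem_len, qlen = 3, 2    # hypothetical segment sizes
klen = mem_len + qlen

# True = attention blocked (same convention as masked_fill_(..., -inf) in the diff)
blocked = torch.zeros(klen, klen, dtype=torch.bool)
# memory queries (first mem_len rows) may only look at memory keys
blocked[:mem_len, mem_len:] = True
# current queries keep the usual causal mask over the current segment
blocked[mem_len:, mem_len:] = torch.triu(torch.ones(qlen, qlen), diagonal=1).bool()

print(blocked.int())    # rows = queries, cols = keys; 1 = blocked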
@@ -66,8 +66,91 @@ class PositionwiseFF(nn.Module):
        return output

class ExtendedMultiHeadAttn(nn.Module):
    # Variant of MultiHeadAttn in which the memory positions are also used as
    # queries ("mem2mem" attention); memory queries may only attend to memory keys.
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False):
        super(ExtendedMultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

    def forward(self, h, attn_mask=None, mems=None):
        ##### multihead attention
        # [hlen x bsz x n_head x d_head]
        if mems is not None:
            c = torch.cat([mems, h], 0)
            mem_len = mems.size(0)
        else:
            c = h
            mem_len = 0

        if self.pre_lnorm:
            ##### layer normalization
            c = self.layer_norm(c)

        head_q = self.q_net(c)
        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)

        head_q = head_q.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score[mem_len:].masked_fill_(attn_mask[None,:,:,None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None], -float('inf'))
            # memory queries ([:mem_len]) may only attend to memory keys
            mem2other_attn = attn_mask.new_ones(mem_len, c.size(0))
            mem2other_attn[:, :mem_len] = 0
            attn_score[:mem_len].masked_fill_(
                mem2other_attn[:, :, None, None], -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
        attn_vec = attn_vec[mem_len:]

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = h + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(h + attn_out)

        return output
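
# --- Not part of the commit: a minimal smoke test for ExtendedMultiHeadAttn,
# --- shown as a sketch only. Sizes are hypothetical; inputs follow the module's
# --- [seq_len x bsz x d_model] convention and mems are [mem_len x bsz x d_model].
#
#     attn = ExtendedMultiHeadAttn(n_head=2, d_model=16, d_head=8, dropout=0.0)
#     h = torch.rand(4, 3, 16)       # current segment
#     mems = torch.rand(5, 3, 16)    # cached memory
#     out = attn(h, mems=mems)       # no attn_mask: all positions visible
#     assert out.shape == h.shape    # only the current positions are returned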

class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()
@@ -380,7 +463,7 @@ class DecoderLayer(nn.Module):
        super(DecoderLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -398,7 +481,7 @@ class RelLearnableDecoderLayer(nn.Module):
        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                  **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -417,7 +500,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                         d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
@@ -431,7 +514,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()
@@ -469,7 +552,7 @@ class AdaptiveEmbedding(nn.Module):
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                   dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
@@ -494,11 +577,11 @@ class AdaptiveEmbedding(nn.Module):
class MemTransformerLM(nn.Module):
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None,
                 div_val=1, tie_projs=[False], pre_lnorm=False,
                 tgt_len=None, ext_len=None, mem_len=None,
                 cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token
@@ -509,7 +592,7 @@ class MemTransformerLM(nn.Module):
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)
@@ -559,7 +642,7 @@ class MemTransformerLM(nn.Module):
        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                    cutoffs, div_val=div_val)

            if tie_weight:
@@ -661,7 +744,7 @@ class MemTransformerLM(nn.Module):
        hids = []
        if self.attn_type == 0: # default
            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
@@ -797,10 +880,10 @@ if __name__ == '__main__':
        for d_embed in [200, 100]:
            model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
                            args.d_model, args.d_head, args.d_inner, args.dropout,
                            dropatt=args.dropout, tie_weight=True,
                            d_embed=d_embed, div_val=div_val,
                            tie_projs=tie_projs, pre_lnorm=True,
                            tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                            cutoffs=cutoffs, attn_type=0).to(device)

            print(sum(p.numel() for p in model.parameters()))