Commit 958c714e authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

add mem2mem attention

parent 44781ed2
...@@ -66,6 +66,89 @@ class PositionwiseFF(nn.Module): ...@@ -66,6 +66,89 @@ class PositionwiseFF(nn.Module):
return output return output
class ExtendedMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False):
super(MultiHeadAttn, self).__init__()
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
def forward(self, h, attn_mask=None, mems=None):
##### multihead attention
# [hlen x bsz x n_head x d_head]
if mems is not None:
c = torch.cat([mems, h], 0)
mem_len = mems.size(0)
else:
c = h
mem_len = 0
if self.pre_lnorm:
##### layer normalization
c = self.layer_norm(c)
head_q = self.q_net(c)
head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
head_q = head_q.view(c.size(0), c.size(1), self.n_head, self.d_head)
head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
# [qlen x klen x bsz x n_head]
attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
attn_score.mul_(self.scale)
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score[mem_len:].masked_fill_(attn_mask[None,:,:,None], -float('inf'))
elif attn_mask.dim() == 3:
attn_score[mem_len:].masked_fill_(attn_mask[:,:,:,None], -float('inf'))
mem2other_attn = attn_mask.ones(mem_len, c.size(0))
mem2other_attn[:, :mem_len] = 0
attn_score[:mem_len].masked_fill_(mem2other_attn, -float('inf'))
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
# [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
attn_vec = attn_vec[mem_len:]
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
output = h + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(h + attn_out)
return output
class MultiHeadAttn(nn.Module): class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False): pre_lnorm=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment