Commit 689e0b24 authored by Guolin Ke
Browse files

Fix a bug when returning attn_weights

parent a92c6297
......@@ -92,24 +92,30 @@ class SelfMultiheadAttention(nn.Module):
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_probs = softmax_dropout(
attn_weights, self.dropout, self.training, bias=attn_bias
if not return_attn:
attn = softmax_dropout(
attn_weights, self.dropout, self.training, bias=attn_bias,
)
else:
attn_weights += attn_bias
attn = softmax_dropout(
attn_weights, self.dropout, self.training,
)
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
o = torch.bmm(attn, v)
assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = (
attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
o = (
o.view(bsz, self.num_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz, tgt_len, embed_dim)
)
attn = self.out_proj(attn)
o = self.out_proj(o)
if not return_attn:
return attn
return o
else:
return attn, attn_weights, attn_probs
return o, attn_weights, attn
class CrossMultiheadAttention(nn.Module):
......@@ -201,16 +207,16 @@ class CrossMultiheadAttention(nn.Module):
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_probs = softmax_dropout(attn_weights, self.dropout, self.training, bias=attn_bias)
attn = softmax_dropout(attn_weights, self.dropout, self.training, bias=attn_bias)
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
o = torch.bmm(attn, v)
assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = (
attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
o = (
o.view(bsz, self.num_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.contiguous()
.view(bsz, tgt_len, embed_dim)
)
attn = self.out_proj(attn)
return attn
o = self.out_proj(o)
return o
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment