Commit 689e0b24 authored by Guolin Ke

fix a bug when returning attn_weights

parent a92c6297
@@ -92,24 +92,30 @@ class SelfMultiheadAttention(nn.Module):
         )
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
-        attn_probs = softmax_dropout(
-            attn_weights, self.dropout, self.training, bias=attn_bias
-        )
+        if not return_attn:
+            attn = softmax_dropout(
+                attn_weights, self.dropout, self.training, bias=attn_bias,
+            )
+        else:
+            attn_weights += attn_bias
+            attn = softmax_dropout(
+                attn_weights, self.dropout, self.training,
+            )
 
-        attn = torch.bmm(attn_probs, v)
-        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
-        attn = (
-            attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        o = torch.bmm(attn, v)
+        assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        o = (
+            o.view(bsz, self.num_heads, tgt_len, self.head_dim)
             .transpose(1, 2)
             .contiguous()
             .view(bsz, tgt_len, embed_dim)
         )
-        attn = self.out_proj(attn)
+        o = self.out_proj(o)
 
         if not return_attn:
-            return attn
+            return o
         else:
-            return attn, attn_weights, attn_probs
+            return o, attn_weights, attn
@@ -201,16 +207,16 @@ class CrossMultiheadAttention(nn.Module):
         )
         attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
-        attn_probs = softmax_dropout(attn_weights, self.dropout, self.training, bias=attn_bias)
+        attn = softmax_dropout(attn_weights, self.dropout, self.training, bias=attn_bias)
 
-        attn = torch.bmm(attn_probs, v)
-        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
-        attn = (
-            attn.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        o = torch.bmm(attn, v)
+        assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        o = (
+            o.view(bsz, self.num_heads, tgt_len, self.head_dim)
             .transpose(1, 2)
             .contiguous()
             .view(bsz, tgt_len, embed_dim)
         )
-        attn = self.out_proj(attn)
-        return attn
+        o = self.out_proj(o)
+        return o
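
Why this fix matters: softmax_dropout takes an optional bias argument that is applied to the logits inside the helper, so the caller's attn_weights tensor never receives attn_bias. When return_attn was true, the old code therefore returned raw attn_weights that were inconsistent with the probabilities it returned alongside them. The new code splits the two cases: without return_attn it keeps the fused bias path, and with return_attn it adds attn_bias into attn_weights first and calls softmax_dropout without the bias, so the returned weights match what was actually softmaxed. The sketch below illustrates the equivalence; the softmax_dropout stand-in here is an assumption inferred from the call signature, not the project's fused implementation.

import torch
import torch.nn.functional as F

# Stand-in for the fused helper, assuming it adds an optional bias to the
# logits before softmax (inferred from the call signature in the diff).
def softmax_dropout(x, dropout, is_training, bias=None):
    if bias is not None:
        x = x + bias  # out of place: the caller's tensor is not modified
    return F.dropout(F.softmax(x, dim=-1), p=dropout, training=is_training)

bsz_heads, tgt_len, src_len = 2, 4, 4
w = torch.randn(bsz_heads, tgt_len, src_len)  # raw attention logits
b = torch.randn(bsz_heads, tgt_len, src_len)  # attn_bias

# Old path: bias is folded in inside the helper; w never sees it.
p_old = softmax_dropout(w, 0.0, False, bias=b)

# New return_attn path: fold the bias into the logits first.
w_biased = w + b
p_new = softmax_dropout(w_biased, 0.0, False)

# Both paths produce identical probabilities...
assert torch.allclose(p_old, p_new)
# ...but only w_biased matches what was actually softmaxed, so it is the
# right tensor to return alongside the probabilities.

The CrossMultiheadAttention hunk only renames attn_probs to attn and attn to o for consistency with the new SelfMultiheadAttention code; its behavior is unchanged.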