Commit 9ffda216 authored by Simon Layton

Fix missed head transpose

parent d51b5894
@@ -284,7 +284,7 @@ class XLNetRelativeAttention(nn.Module):
         # Mask heads if we want to
         if head_mask is not None:
-            attn_prob = attn_prob * head_mask
+            attn_prob = attn_prob * torch.einsum('ijbn->bnij', head_mask)
         # attention output
         attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h)
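For reference, the einsum added here is a pure axis transpose. Below is a minimal sketch of why it is needed, assuming head_mask still arrives in the upstream (qlen, klen, bsz, n_head) ordering while this branch keeps attn_prob in (bsz, n_head, qlen, klen); the sizes are made up for illustration.

import torch

# Hypothetical sizes: qlen=i, klen=j, batch=b, heads=n
i, j, b, n = 5, 5, 2, 4
head_mask = torch.ones(i, j, b, n)   # assumed (qlen, klen, bsz, n_head) layout
attn_prob = torch.rand(b, n, i, j)   # probabilities kept as (bsz, n_head, qlen, klen)

# 'ijbn->bnij' only reorders axes, equivalent to permute(2, 3, 0, 1),
# so the mask lines up element-wise with attn_prob.
mask_t = torch.einsum('ijbn->bnij', head_mask)
assert mask_t.shape == attn_prob.shape
assert torch.equal(mask_t, head_mask.permute(2, 3, 0, 1))

attn_prob = attn_prob * mask_t

Without the transpose, the element-wise multiply would in general fail to broadcast against the reordered attn_prob, which is what the commit message calls the missed head transpose.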