Commit 274d850d authored by Julien Chaumond

Fix #4098

parent 26dad0a9
@@ -145,7 +145,7 @@ class Attention(nn.Module):
         w = w / (v.size(-1) ** 0.5)
         nd, ns = w.size(-2), w.size(-1)
         mask = self.bias[:, :, ns - nd : ns, :ns]
-        w = torch.where(mask, w, self.masked_bias)
+        w = torch.where(mask, w, self.masked_bias.to(w.dtype))
         if attention_mask is not None:
             # Apply the attention mask
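The change casts the `masked_bias` fill value to the dtype of the attention scores before passing it to `torch.where`. A minimal sketch of why this matters, assuming a half-precision forward pass (the tensor shapes and the `-1e4` fill value here are illustrative, not taken from the commit): `torch.where` with an fp16 score tensor and an fp32 fill buffer raised a dtype-mismatch error on PyTorch releases of that era, and on later releases it would promote the result to fp32 instead of keeping it in fp16.

```python
import torch

# Attention scores in half precision, as in an fp16 forward pass.
w = torch.randn(1, 1, 4, 4, dtype=torch.float16)

# Causal mask plus the fill buffer, mirroring the unpatched code
# (the -1e4 value is illustrative).
mask = torch.tril(torch.ones(4, 4, dtype=torch.bool)).view(1, 1, 4, 4)
masked_bias = torch.tensor(-1e4)  # float32 by default

# Unpatched: torch.where(mask, w, masked_bias) mixes fp16 and fp32,
# which errors on older PyTorch and upcasts w on newer releases.
# Patched: cast the fill value to w's dtype so the result stays fp16.
w = torch.where(mask, w, masked_bias.to(w.dtype))
print(w.dtype)  # torch.float16
```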