Unverified Commit 6060b2f8 authored by ziliwang, committed by GitHub

fix: hard coding for max number

The maximum representable fp16 value is 65504, so the original constant 1e30 overflows in fp16 and produces NaN.
parent caf1d116
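
A minimal, standalone sketch (not part of the commit; the `attn_score`/`attn_mask` tensors below are made up for illustration) of why the constant matters in fp16: 1e30 overflows to inf, so masked scores become -inf and the softmax can return NaN, while 65500 still fits below the fp16 maximum and keeps the scores finite.

```python
import torch
import torch.nn.functional as F

# fp16 range check: 1e30 overflows, 65500 does not.
print(torch.finfo(torch.float16).max)             # 65504.0
print(torch.tensor(1e30, dtype=torch.float16))    # inf  (overflow)
print(torch.tensor(65500., dtype=torch.float16))  # finite (rounds to 65504)

# Illustrative fully-masked softmax row (on older PyTorch builds,
# fp16 ops may require moving these tensors to a CUDA device).
attn_score = torch.zeros(3, dtype=torch.float16)
attn_mask = torch.ones(3, dtype=torch.float16)

bad = attn_score - 1e30 * attn_mask    # all entries become -inf
good = attn_score - 65500 * attn_mask  # all entries stay finite (-65504)

print(F.softmax(bad, dim=0))   # NaN
print(F.softmax(good, dim=0))  # valid probabilities
```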
@@ -418,7 +418,10 @@ class XLNetRelativeAttention(nn.Module):
         attn_score = (ac + bd + ef) * self.scale
         if attn_mask is not None:
             # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
-            attn_score = attn_score - 1e30 * attn_mask
+            if attn_mask.dtype == torch.float16:
+                attn_score = attn_score - 65500 * attn_mask
+            else:
+                attn_score = attn_score - 1e30 * attn_mask
 
         # attention probability
         attn_prob = F.softmax(attn_score, dim=1)