Unverified Commit ec94f4e0 authored by Simon Layton's avatar Simon Layton Committed by GitHub
Browse files

Fix fp16 masking in PoolerEndLogits

Necessary to run xlnet (at least in squad) with `--fp16 --fp16_opt_level="O2"`, otherwise loss is immediately `NaN` and fine-tuning cannot proceed.
parent 32e1332a
......@@ -478,6 +478,9 @@ class PoolerEndLogits(nn.Module):
x = self.dense_1(x).squeeze(-1)
if p_mask is not None:
if next(self.parameters()).dtype == torch.float16:
x = x * (1 - p_mask) - 65500 * p_mask
else:
x = x * (1 - p_mask) - 1e30 * p_mask
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment