Commit e21f89f6, authored by Stas Bekman, committed by GitHub

fix nan in full-fp16 label_smoothing eval (#10815)

parent b5b957a6
@@ -433,7 +433,8 @@ class LabelSmoother:
# will ignore them in any case.
labels.clamp_min_(0)
nll_loss = log_probs.gather(dim=-1, index=labels)
-smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
+# works for fp16 input tensor too, by internally upcasting it to fp32
+smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
nll_loss.masked_fill_(padding_mask, 0.0)
smoothed_loss.masked_fill_(padding_mask, 0.0)
......
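The effect of the `dtype=torch.float32` argument can be sketched in isolation. The snippet below is an illustration only, not part of the commit: the vocabulary size, tensor shape, and constant fill value are assumptions chosen so that the sum over the last dimension exceeds the largest finite fp16 value (65504). Summing directly in fp16 is then expected to overflow, while requesting an fp32 result keeps the sum finite.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
vocab_size = 50_000

# Stand-in for the per-token values summed by LabelSmoother: a constant
# close to log(vocab_size), stored in half precision.
log_probs = torch.full((2, 3, vocab_size), 10.8, dtype=torch.float16)

# Pre-fix behaviour: the result dtype follows the input (fp16), so a sum
# larger than 65504 cannot be represented and is expected to overflow.
fp16_sum = log_probs.sum(dim=-1, keepdim=True)

# Post-fix behaviour: the input is cast to fp32 before the reduction,
# and the result is returned in fp32.
fp32_sum = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)

fp16_max = torch.finfo(torch.float16).max
print("fp16 result dtype:", fp16_sum.dtype, "all finite:", torch.isfinite(fp16_sum).all().item())
print("fp32 result dtype:", fp32_sum.dtype, "all finite:", torch.isfinite(fp32_sum).all().item())
print("sum exceeds fp16 max:", (fp32_sum > fp16_max).all().item())
```

Because `smoothed_loss` is now fp32 while `nll_loss` stays in the input dtype, later arithmetic that combines the two follows PyTorch's usual type-promotion rules.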