Unverified commit cd665396, authored by Sylvain Gugger, committed by GitHub

Don't modify labels inplace in `LabelSmoother` (#13464)

parent c164c651
@@ -458,7 +458,7 @@ class LabelSmoother:
         padding_mask = labels.eq(self.ignore_index)
         # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
         # will ignore them in any case.
-        labels.clamp_min_(0)
+        labels = torch.clamp(labels, min=0)
         nll_loss = log_probs.gather(dim=-1, index=labels)
         # works for fp16 input tensor too, by internally upcasting it to fp32
         smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
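The bug being fixed: `clamp_min_` (trailing underscore marks an in-place PyTorch op) mutates the caller's `labels` tensor, permanently overwriting the `-100` ignore-index sentinels, whereas `torch.clamp` returns a new tensor and leaves the caller's data untouched. A minimal sketch of the same aliasing hazard, using plain Python lists as stand-ins for tensors (hypothetical helper names, not the library's API):

```python
IGNORE_INDEX = -100  # sentinel marking padded label positions


def clamp_min_inplace(labels, lo):
    # Analogue of labels.clamp_min_(lo): mutates the caller's data.
    for i, v in enumerate(labels):
        if v < lo:
            labels[i] = lo


def clamp_min(labels, lo):
    # Analogue of torch.clamp(labels, min=lo): returns a new object.
    return [max(v, lo) for v in labels]


caller_labels = [5, IGNORE_INDEX, 7]

# Out-of-place: the caller's sentinels survive.
clamped = clamp_min(caller_labels, 0)
assert caller_labels == [5, -100, 7]
assert clamped == [5, 0, 7]

# In-place: the sentinels are silently destroyed for the caller too,
# breaking any later code that relies on labels.eq(IGNORE_INDEX).
clamp_min_inplace(caller_labels, 0)
assert caller_labels == [5, 0, 7]
```

The patched code still gets zeros at padded positions for the `gather` call, but the precomputed `padding_mask` (and the caller's tensor) retain the information needed to zero out those positions afterwards.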