Unverified Commit 461e8cac authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix evaluation with label smoothing in Trainer (#10338)

parent 622a8c59
......@@ -1888,6 +1888,14 @@ class Trainer:
else:
ignore_keys = []
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
if has_labels:
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
else:
labels = None
with torch.no_grad():
if has_labels:
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
......@@ -1918,13 +1926,6 @@ class Trainer:
if len(logits) == 1:
logits = logits[0]
if has_labels:
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
else:
labels = None
return (loss, logits, labels)
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment