Unverified Commit 0106826a authored by Kwanghee Choi's avatar Kwanghee Choi Committed by GitHub
Browse files

Fix missing autocast() in Trainer.prediction_step() (#14075)


Co-authored-by: default avatarjonas <jonas@hpcnt.com>
parent a43d9352
......@@ -2486,6 +2486,10 @@ class Trainer:
logits = smp_nested_concat(logits_mb)
else:
if has_labels:
if self.use_amp:
with autocast():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
else:
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
if isinstance(outputs, dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment