Unverified Commit 0106826a authored by Kwanghee Choi's avatar Kwanghee Choi Committed by GitHub
Browse files

Fix missing autocast() in Trainer.prediction_step() (#14075)


Co-authored-by: default avatarjonas <jonas@hpcnt.com>
parent a43d9352
...@@ -2486,6 +2486,10 @@ class Trainer: ...@@ -2486,6 +2486,10 @@ class Trainer:
logits = smp_nested_concat(logits_mb) logits = smp_nested_concat(logits_mb)
else: else:
if has_labels: if has_labels:
if self.use_amp:
with autocast():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
else:
loss, outputs = self.compute_loss(model, inputs, return_outputs=True) loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach() loss = loss.mean().detach()
if isinstance(outputs, dict): if isinstance(outputs, dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment