Commit 14cc50d0 authored by Stas Bekman's avatar Stas Bekman
Browse files

fix autocast for older pytorch

parent 4c0dd199
......@@ -1857,8 +1857,12 @@ class Trainer:
return loss_mb.reduce_mean().detach().to(self.args.device)
if self.use_amp:
if version.parse(torch.__version__) >= version.parse("1.10"):
with autocast(dtype=self.amp_dtype):
loss = self.compute_loss(model, inputs)
else:
with autocast():
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs)
......@@ -2501,8 +2505,12 @@ class Trainer:
else:
if has_labels:
if self.use_amp:
if version.parse(torch.__version__) >= version.parse("1.10"):
with autocast(dtype=self.amp_dtype):
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
else:
with autocast():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
else:
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
......@@ -2514,8 +2522,12 @@ class Trainer:
else:
loss = None
if self.use_amp:
if version.parse(torch.__version__) >= version.parse("1.10"):
with autocast(dtype=self.amp_dtype):
outputs = model(**inputs)
else:
with autocast():
outputs = model(**inputs)
else:
outputs = model(**inputs)
if isinstance(outputs, dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment