"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "df7cd9e4e40a8c89f7570fd6cb3dbd30a803123a"
Commit 14cc50d0 authored by Stas Bekman's avatar Stas Bekman
Browse files

fix autocast for older pytorch

parent 4c0dd199
...@@ -1857,8 +1857,12 @@ class Trainer: ...@@ -1857,8 +1857,12 @@ class Trainer:
return loss_mb.reduce_mean().detach().to(self.args.device) return loss_mb.reduce_mean().detach().to(self.args.device)
if self.use_amp: if self.use_amp:
with autocast(dtype=self.amp_dtype): if version.parse(torch.__version__) >= version.parse("1.10"):
loss = self.compute_loss(model, inputs) with autocast(dtype=self.amp_dtype):
loss = self.compute_loss(model, inputs)
else:
with autocast():
loss = self.compute_loss(model, inputs)
else: else:
loss = self.compute_loss(model, inputs) loss = self.compute_loss(model, inputs)
...@@ -2501,8 +2505,12 @@ class Trainer: ...@@ -2501,8 +2505,12 @@ class Trainer:
else: else:
if has_labels: if has_labels:
if self.use_amp: if self.use_amp:
with autocast(dtype=self.amp_dtype): if version.parse(torch.__version__) >= version.parse("1.10"):
loss, outputs = self.compute_loss(model, inputs, return_outputs=True) with autocast(dtype=self.amp_dtype):
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
else:
with autocast():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
else: else:
loss, outputs = self.compute_loss(model, inputs, return_outputs=True) loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach() loss = loss.mean().detach()
...@@ -2514,8 +2522,12 @@ class Trainer: ...@@ -2514,8 +2522,12 @@ class Trainer:
else: else:
loss = None loss = None
if self.use_amp: if self.use_amp:
with autocast(dtype=self.amp_dtype): if version.parse(torch.__version__) >= version.parse("1.10"):
outputs = model(**inputs) with autocast(dtype=self.amp_dtype):
outputs = model(**inputs)
else:
with autocast():
outputs = model(**inputs)
else: else:
outputs = model(**inputs) outputs = model(**inputs)
if isinstance(outputs, dict): if isinstance(outputs, dict):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment