Commit 841e5ee1 authored by Thor Johnsen

Bug fix: give _step() separate check_overflow and undo flags, record which output params a step modified, and run strided_check_finite over just those tensors so step() can detect overflow and undo the update when allow_undo is set.

parent 1d4a95d4
@@ -113,14 +113,15 @@ class FusedAdam(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()
-        self._step(grads, output_params, scale, grad_norms, False)
-        if allow_overflow:
-            self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())
-            if self.peek_overflow:
-                self._step(grads, output_params, scale, grad_norms, True)
+        self._step(grads, output_params, scale, grad_norms, allow_undo, False)
+        if allow_undo and self.peek_overflow:
+            self._step(grads, output_params, scale, grad_norms, False, True)
         return loss

-    def _step(self, grads, output_params, scale, grad_norms, undo):
+    def _step(self, grads, output_params, scale, grad_norms, check_overflow, undo):
+        if check_overflow:
+            modified_params = []
         if hasattr(self, "_amp_stash"):
             grads = self._amp_stash.grads
             output_params = self._amp_stash.output_params
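This first hunk changes the control flow of step(): _step() is initially called with check_overflow set to allow_undo so that it records the tensors it writes and refreshes the overflow flag, and if that pass raises the flag, a second call with undo=True reverses the update. A minimal sketch of that flow, assuming plain SGD in place of the fused Adam kernel; only the step/_step/peek_overflow names and flag order come from the diff, everything else (UndoableOptimizer, lr) is illustrative:

```python
import torch

class UndoableOptimizer:
    """Sketch only: SGD stands in for the fused Adam update."""

    def __init__(self, params, lr=0.1):
        self.params = list(params)
        self.lr = lr
        self.peek_overflow = False

    def step(self, grads, allow_undo=True):
        # First pass applies the update; check_overflow=allow_undo makes
        # it record what it wrote and refresh peek_overflow.
        self._step(grads, check_overflow=allow_undo, undo=False)
        if allow_undo and self.peek_overflow:
            # Overflow detected after the fact: reverse the update exactly.
            self._step(grads, check_overflow=False, undo=True)

    def _step(self, grads, check_overflow, undo):
        if check_overflow:
            modified_params = []
        for p, g in zip(self.params, grads):
            if undo:
                p.add_(g, alpha=self.lr)   # inverse of the update below
            else:
                p.sub_(g, alpha=self.lr)
                if check_overflow:
                    modified_params.append(p)
        if check_overflow:
            # Scan only the tensors this pass actually modified.
            self.peek_overflow = any(
                not torch.isfinite(p).all() for p in modified_params)
```

With SGD the undo is exact; the real fused kernel's undo branch presumably also rolls back the optimizer state it advanced.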
@@ -206,6 +207,8 @@ class FusedAdam(torch.optim.Optimizer):
             if not undo:
                 out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
+                if check_overflow:
+                    modified_params.append(out_p)
             if self._use_multi_tensor:
                 pl = [p.data, exp_avg, exp_avg_sq, grad]
                 if not undo and output_param is not None:
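The second hunk records each output param that the forward (non-undo) pass writes. When a parameter has no separate fp16 output copy, out_p falls back to an empty tensor, which keeps the bookkeeping uniform: the later finite check simply scans zero elements. A small demonstration of that placeholder behaviour (the output_param value here is a stand-in):

```python
import torch

output_param = None  # stand-in: this param has no separate fp16 output copy
out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param

print(out_p.numel())                        # 0 -> nothing for the check to scan
print(torch.isfinite(out_p).all().item())   # True: vacuously finite
```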
@@ -274,4 +277,8 @@ class FusedAdam(torch.optim.Optimizer):
                          bias_correction,
                          group['weight_decay'])

+        if check_overflow:
+            for i, out_p in enumerate(modified_params):
+                self.strided_check_finite(out_p, stride=out_p.numel(), start=0, end=out_p.numel(), clear=True if i == 0 else False)
         return loss
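The third hunk runs the finite check once, after all updates, over exactly the tensors collected in modified_params, clearing the overflow flag only on the first call so results accumulate across the sweep. strided_check_finite itself is a fused kernel; the pure-PyTorch stand-in below only mirrors the call's visible semantics (scan elements [start, end) of each stride-sized block and OR the result into a persistent flag that clear=True resets):

```python
import torch

class OverflowChecker:
    """Pure-PyTorch stand-in for the fused strided_check_finite kernel."""

    def __init__(self):
        self._overflow = False

    def strided_check_finite(self, t, stride, start, end, clear):
        if clear:
            self._overflow = False  # first call of a sweep resets the flag
        flat = t.reshape(-1)
        if stride <= 0:             # empty placeholder tensors: nothing to do
            return
        for base in range(0, flat.numel(), stride):
            chunk = flat[base + start : base + min(end, stride)]
            if chunk.numel() and not torch.isfinite(chunk).all():
                self._overflow = True  # OR results across blocks and calls

    @property
    def peek_overflow(self):
        return self._overflow

checker = OverflowChecker()
modified_params = [torch.ones(4), torch.tensor([1.0, float('inf')])]
for i, out_p in enumerate(modified_params):
    checker.strided_check_finite(out_p, stride=out_p.numel(), start=0,
                                 end=out_p.numel(), clear=(i == 0))
print(checker.peek_overflow)  # True: the second tensor contains inf
```

With stride equal to numel() and start=0, end=numel(), each tensor is scanned as a single block, so the per-call arguments here reduce to "check the whole tensor".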