Commit eb8384b5 authored by Thor Johnsen

Modify fused_adam to take advantage of undo feature

parent f1e565f5
@@ -92,7 +92,7 @@ class FusedAdam(torch.optim.Optimizer):
                 stride,
                 1 if clear else 0)

-    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
+    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None, allow_undo=False):
         """Performs a single optimization step.

         Arguments:
@@ -106,15 +106,18 @@ class FusedAdam(torch.optim.Optimizer):
                 updated weights. Have to be of same type as gradients. (default: None)
             scale (float, optional): factor to divide gradient tensor values
                 by before applying to weights. (default: 1)
+            allow_undo (bool, optional): allow use of the undo feature. Internal
+                buffers are restored to their pre-step state if an overflow is
+                detected in the gradients. (default: False)
         """
         loss = None
         if closure is not None:
             loss = closure()
         self._step(grads, output_params, scale, grad_norms, False)
-        self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())
-        if self.peek_overflow:
-            self._step(grads, output_params, scale, grad_norms, True)
+        if allow_undo:
+            self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())
+            if self.peek_overflow:
+                self._step(grads, output_params, scale, grad_norms, True)
         return loss

     def _step(self, grads, output_params, scale, grad_norms, undo):
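To make the new control flow concrete, below is a rough mental model of what the undo feature provides, written in plain PyTorch. This is an illustrative sketch only, not apex's fused CUDA implementation: the real optimizer handles this inside its own kernels and internal buffers, whereas this version simply snapshots parameters and optimizer state before the update and restores them when a non-finite gradient is found.

import torch

def step_with_undo(optimizer, allow_undo=True):
    """Illustrative sketch: apply optimizer.step(), then roll parameters and
    optimizer state back to their pre-step values if any gradient is
    non-finite. Mirrors the intent of self._step(..., undo=True) in the diff
    above, not its implementation."""
    params = [p for group in optimizer.param_groups for p in group["params"]]
    if allow_undo:
        # Snapshot everything the update will touch.
        saved_params = [p.detach().clone() for p in params]
        saved_state = {
            p: {k: (v.clone() if torch.is_tensor(v) else v)
                for k, v in optimizer.state[p].items()}
            for p in params if p in optimizer.state
        }
    optimizer.step()
    # Check for overflow after the update, as the modified step() above does.
    overflow = any(p.grad is not None and not torch.isfinite(p.grad).all()
                   for p in params)
    if allow_undo and overflow:
        # Restore the pre-step state so the step is effectively skipped.
        with torch.no_grad():
            for p, saved in zip(params, saved_params):
                p.copy_(saved)
        for p in params:
            optimizer.state[p] = saved_state.get(p, {})
    return overflow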
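And a sketch of how a caller might pair this with dynamic loss scaling. The model, data, and scale schedule here are hypothetical; with the change in this commit, the same pattern would call optimizer.step(..., allow_undo=True) and consult peek_overflow instead of going through the step_with_undo helper above.

# Hypothetical usage of the step_with_undo sketch above, together with a
# simple dynamic loss-scaling schedule.
import torch

model = torch.nn.Linear(64, 64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

loss_scale = 2.0 ** 16
for _ in range(100):
    inputs = torch.randn(8, 64)
    targets = torch.randn(8, 64)

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    (loss * loss_scale).backward()
    # Unscale before the update; inf/nan produced by the scaled backward pass
    # survives the division and is caught by the check in step_with_undo.
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is not None:
                p.grad.div_(loss_scale)

    if step_with_undo(optimizer):        # True: overflow, step was rolled back
        loss_scale = max(loss_scale / 2.0, 1.0)
    else:
        # A real schedule would typically grow the scale only after many
        # consecutive clean steps; doubling every step keeps the sketch short.
        loss_scale = min(loss_scale * 2.0, 2.0 ** 24)

The likely benefit of the approach in the commit (an assumption, not stated in the diff) is that the finite check and the eventual rollback stay inside the fused optimizer, so the common no-overflow case pays for no separate pass over the gradients before the update.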