Commit eb8384b5 authored by Thor Johnsen's avatar Thor Johnsen

Modify fused_adam to take advantage of undo feature

parent f1e565f5
...@@ -92,7 +92,7 @@ class FusedAdam(torch.optim.Optimizer):
                 stride,
                 1 if clear else 0)
-    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
+    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None, allow_undo=False):
         """Performs a single optimization step.
         Arguments:
...@@ -106,12 +106,15 @@ class FusedAdam(torch.optim.Optimizer):
                 updated weights. Have to be of same type as gradients. (default: None)
             scale (float, optional): factor to divide gradient tensor values
                 by before applying to weights. (default: 1)
+            allow_undo (bool, optional): allow use of undo feature. Internal buffers
+                will be restored to pre-step state if overflow is detected in gradient.
         """
         loss = None
         if closure is not None:
             loss = closure()
         self._step(grads, output_params, scale, grad_norms, False)
-        self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())
-        if self.peek_overflow:
-            self._step(grads, output_params, scale, grad_norms, True)
+        if allow_undo:
+            self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())
+            if self.peek_overflow:
+                self._step(grads, output_params, scale, grad_norms, True)
...
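The docstring above states the behaviour at a high level: take the Adam step first, then check for an overflow, and roll the optimizer's internal buffers back to their pre-step state if one is found (stepping first and undoing later presumably lets the update kernel launch without waiting on a synchronizing overflow check). The plain-PyTorch sketch below illustrates that idea with explicit copies; step_with_undo is a hypothetical helper written for this note, not part of apex, and the fused path in the diff achieves the same effect without the extra snapshots by calling _step a second time in undo mode when peek_overflow reports non-finite values.

# Plain-PyTorch sketch of the behaviour the allow_undo flag describes (hypothetical
# helper, not the fused CUDA implementation): snapshot state, take the step, and
# restore everything if the gradients turn out to contain inf/nan.
import copy
import torch

def step_with_undo(optimizer, params):
    params = list(params)
    # Snapshot weights and optimizer buffers (exp_avg, exp_avg_sq, step counts).
    saved_params = [p.detach().clone() for p in params]
    saved_state = copy.deepcopy(optimizer.state_dict())

    optimizer.step()

    # Overflow check mirrors the docstring: non-finite values in the gradient.
    overflow = any(p.grad is not None and not torch.isfinite(p.grad).all() for p in params)
    if overflow:
        # Undo the step: restore weights and internal buffers to their pre-step state.
        with torch.no_grad():
            for p, saved in zip(params, saved_params):
                p.copy_(saved)
        optimizer.load_state_dict(saved_state)
    return overflow

# Toy usage: a dynamic-loss-scaling loop would lower its loss scale and skip the
# batch whenever the step was undone, which is how a caller could use allow_undo.
model = torch.nn.Linear(16, 16)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(4, 16)).pow(2).mean().backward()
undone = step_with_undo(opt, model.parameters())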