Commit eb8384b5 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Modify fused_adam to take advantage of undo feature

parent f1e565f5
...@@ -92,7 +92,7 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -92,7 +92,7 @@ class FusedAdam(torch.optim.Optimizer):
stride, stride,
1 if clear else 0) 1 if clear else 0)
def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None): def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None, allow_undo=False):
"""Performs a single optimization step. """Performs a single optimization step.
Arguments: Arguments:
...@@ -106,15 +106,18 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -106,15 +106,18 @@ class FusedAdam(torch.optim.Optimizer):
updated weights. Have to be of same type as gradients. (default: None) updated weights. Have to be of same type as gradients. (default: None)
scale (float, optional): factor to divide gradient tensor values scale (float, optional): factor to divide gradient tensor values
by before applying to weights. (default: 1) by before applying to weights. (default: 1)
	            allow_undo (bool, optional): allow use of the undo feature. Internal buffers
	                will be restored to their pre-step state if overflow is detected in the gradient.
""" """
loss = None loss = None
if closure is not None: if closure is not None:
loss = closure() loss = closure()
self._step(grads, output_params, scale, grad_norms, False) self._step(grads, output_params, scale, grad_norms, False)
        self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())	        if allow_undo:
if self.peek_overflow: self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())
self._step(grads, output_params, scale, grad_norms, True) if self.peek_overflow:
self._step(grads, output_params, scale, grad_norms, True)
return loss return loss
    def _step(self, grads, output_params, scale, grad_norms, undo):	    def _step(self, grads, output_params, scale, grad_norms, undo):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment