Commit f1e565f5 authored by Thor Johnsen

Modify fused_adam to take advantage of undo feature

parent c659e564
@@ -61,6 +61,37 @@ class FusedAdam(torch.optim.Optimizer):
        super(FusedAdam, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        self._overflow_buf = torch.cuda.IntTensor([0])

    @property
    def has_overflow(self):
        """Check if overflows were detected by any call to step(...) method.
        Clears the overflow flag.
        """
        has_overflow = self._overflow_buf.item()
        self._overflow_buf.zero_()
        return has_overflow

    @property
    def peek_overflow(self):
        """Check if overflows were detected by any call to step(...) method.
        Does not clear overflow flag.
        """
        return self._overflow_buf.item()

    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
        """Strided check for overflow.
        You can get status by calling has_overflow.
        """
        if start >= 0 and start < end:
            out_p = output_params[start:end]
        else:
            out_p = output_params
        fused_adam_cuda.strided_check_finite(self._overflow_buf,
                                             out_p,
                                             stride,
                                             1 if clear else 0)

    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
        """Performs a single optimization step.
@@ -80,6 +111,13 @@ class FusedAdam(torch.optim.Optimizer):
        if closure is not None:
            loss = closure()

        self._step(grads, output_params, scale, grad_norms, False)
        self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())
        if self.peek_overflow:
            self._step(grads, output_params, scale, grad_norms, True)
        return loss

    def _step(self, grads, output_params, scale, grad_norms, undo):
        if hasattr(self, "_amp_stash"):
            grads = self._amp_stash.grads
            output_params = self._amp_stash.output_params
@@ -143,6 +181,9 @@ class FusedAdam(torch.optim.Optimizer):
                state = self.state[p]
                # State initialization
                if undo:
                    assert (len(state) > 0), "Adam undo called with empty optimizer state"
                else:
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
@@ -153,16 +194,37 @@ class FusedAdam(torch.optim.Optimizer):
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                if undo:
                    step = state['step']
                    state['step'] -= 1
                else:
                    state['step'] += 1
                    step = state['step']

                if not undo:
                    out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
                if self._use_multi_tensor:
                    pl = [p.data, exp_avg, exp_avg_sq, grad]
                    if not undo and output_param is not None:
                        pl.append(out_p)
                    for tl, t in zip(tensorlists, pl):
                        tl.append(t)
                else:
                    if undo:
                        fused_adam_cuda.adam_undo(p.data,
                                                  exp_avg,
                                                  exp_avg_sq,
                                                  grad,
                                                  group['lr'],
                                                  beta1,
                                                  beta2,
                                                  group['eps'],
                                                  combined_scale,
                                                  step,
                                                  self.eps_mode,
                                                  bias_correction,
                                                  group['weight_decay'])
                    else:
                        fused_adam_cuda.adam(p.data,
                                             out_p,
@@ -174,12 +236,27 @@ class FusedAdam(torch.optim.Optimizer):
                                             beta2,
                                             group['eps'],
                                             combined_scale,
                                             step,
                                             self.eps_mode,
                                             bias_correction,
                                             group['weight_decay'])
        if self._use_multi_tensor:
            if undo:
                multi_tensor_applier(
                    fused_adam_cuda.adam_undo_mt,
                    self._overflow_buf,
                    tensorlists,
                    group['lr'],
                    beta1,
                    beta2,
                    group['eps'],
                    combined_scale,
                    step,
                    self.eps_mode,
                    bias_correction,
                    group['weight_decay'])
            else:
                multi_tensor_applier(
                    fused_adam_cuda.adam_mt,
                    self._overflow_buf,
@@ -189,7 +266,7 @@ class FusedAdam(torch.optim.Optimizer):
                    beta2,
                    group['eps'],
                    combined_scale,
                    step,
                    self.eps_mode,
                    bias_correction,
                    group['weight_decay'])
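The flow this commit sets up: step() applies the fused update optimistically, runs strided_check_finite over the updated output parameters, and, if the shared _overflow_buf got set, calls _step(..., undo=True) to revert the update, so the caller only has to poll has_overflow afterwards. Below is a minimal CPU-only sketch of those flag semantics and of the apply/check/undo pattern; _OverflowDemo, its check_finite and step helpers, and the clone-based rollback are illustrative stand-ins, not part of the commit, which reverses the Adam math inside the fused_adam_cuda.adam_undo / adam_undo_mt kernels and never clones parameters.

import torch

class _OverflowDemo:
    """Illustration only: mimics FusedAdam's new overflow bookkeeping on CPU."""

    def __init__(self):
        # Stands in for self._overflow_buf = torch.cuda.IntTensor([0]).
        self._overflow_buf = torch.IntTensor([0])

    @property
    def has_overflow(self):
        # Read-and-clear, mirroring FusedAdam.has_overflow.
        flag = self._overflow_buf.item()
        self._overflow_buf.zero_()
        return flag

    @property
    def peek_overflow(self):
        # Read without clearing, mirroring FusedAdam.peek_overflow.
        return self._overflow_buf.item()

    def check_finite(self, t):
        # Stand-in for a finiteness check that sets the shared flag
        # whenever it sees inf/nan in the checked tensor.
        if not torch.isfinite(t).all():
            self._overflow_buf.fill_(1)

    def step(self, param, update):
        # Apply the update, check the result, roll back on overflow -- the
        # same shape as step() -> _step(undo=False) -> peek_overflow ->
        # _step(undo=True). The rollback here is a clone/copy; the commit
        # instead recomputes the inverse Adam update in a CUDA kernel.
        backup = param.clone()
        param += update
        self.check_finite(param)
        if self.peek_overflow:          # peek: leave the flag set for the caller
            param.copy_(backup)

opt = _OverflowDemo()
p = torch.zeros(4)

opt.step(p, torch.ones(4))                   # clean update sticks
assert p.eq(1).all()
assert opt.peek_overflow == 0

opt.step(p, torch.full((4,), float('inf')))  # overflowing update is rolled back
assert p.eq(1).all()
assert opt.peek_overflow == 1                # peek leaves the flag set
assert opt.has_overflow == 1                 # has_overflow reports it ...
assert opt.peek_overflow == 0                # ... and clears it

Note that step() in the commit checks peek_overflow internally rather than has_overflow, so the flag is still set when the caller later reads has_overflow (for example, to decide whether to back off a loss scale); that read then clears the buffer for the next iteration.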