"vscode:/vscode.git/clone" did not exist on "6a905be5ced93c46e35b675fbdc73d40bb95d3ee"
Commit d48218a0 authored by Thor Johnsen

Modify fused_adam to take advantage of undo feature

parent c7372320
@@ -61,6 +61,37 @@ class FusedAdam(torch.optim.Optimizer):
        super(FusedAdam, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1
        self._overflow_buf = torch.cuda.IntTensor([0])

    @property
    def has_overflow(self):
        """Check if overflows were detected by any call to step(...) method.
        Clears the overflow flag.
        """
        has_overflow = self._overflow_buf.item()
        self._overflow_buf.zero_()
        return has_overflow

    @property
    def peek_overflow(self):
        """Check if overflows were detected by any call to step(...) method.
        Does not clear overflow flag.
        """
        return self._overflow_buf.item()

    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
        """Strided check for overflow.
        You can get status by calling has_overflow.
        """
        if start >= 0 and start < end:
            out_p = output_params[start:end]
        else:
            out_p = output_params
        fused_adam_cuda.strided_check_finite(self._overflow_buf,
                                             out_p,
                                             stride,
                                             1 if clear else 0)

    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
        """Performs a single optimization step.
@@ -80,6 +111,13 @@ class FusedAdam(torch.optim.Optimizer):
        if closure is not None:
            loss = closure()

        self._step(grads, output_params, scale, grad_norms, False)
        self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())
        if self.peek_overflow:
            self._step(grads, output_params, scale, grad_norms, True)
        return loss

    def _step(self, grads, output_params, scale, grad_norms, undo):
        if hasattr(self, "_amp_stash"):
            grads = self._amp_stash.grads
            output_params = self._amp_stash.output_params
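The new step() therefore becomes: run _step(..., undo=False) tentatively, check the written output params for inf/nan, and if the flag is up, call _step(..., undo=True) to roll the update back, leaving the flag set for the caller. A toy sketch of that control flow (plain SGD stands in for the fused Adam kernel; all names below are illustrative, not the commit's API):

import torch

def toy_step(p, grad, lr, undo=False):
    # _step() analogue: one plain SGD update, or its exact inverse when undo=True.
    p.add_(grad, alpha=(lr if undo else -lr))

def toy_check_finite(overflow_buf, out_p):
    # strided_check_finite() analogue: raise the flag if any element is inf/nan.
    overflow_buf.fill_(0 if torch.isfinite(out_p).all() else 1)

overflow_buf = torch.zeros(1, dtype=torch.int)
p32 = torch.tensor([1.0, 50000.0])     # fp32 master weights
p16 = p32.half()                       # fp16 "output_params" copy
grad = torch.tensor([0.5, -30000.0])   # finite in fp32

toy_step(p32, grad, lr=1.0)            # tentative step -> [0.5, 80000.0]
p16.copy_(p32)                         # 80000 overflows fp16 -> inf
toy_check_finite(overflow_buf, p16)
if overflow_buf.item():                # peek_overflow analogue: read, don't clear
    toy_step(p32, grad, lr=1.0, undo=True)   # roll the master weights back

print(p32)                  # restored to [1.0, 50000.0]
print(overflow_buf.item())  # 1 -- still set, so has_overflow can report it and
                            # the caller can lower its loss scale before retrying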
@@ -143,55 +181,94 @@ class FusedAdam(torch.optim.Optimizer):
                state = self.state[p]

                # State initialization
                if undo:
                    assert (len(state) > 0), "Adam undo called with empty optimizer state"
                else:
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p.data)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                if undo:
                    step = state['step']
                    state['step'] -= 1
                else:
                    state['step'] += 1
                    step = state['step']

                if not undo:
                    out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
                if self._use_multi_tensor:
                    pl = [p.data, exp_avg, exp_avg_sq, grad]
                    if not undo and output_param is not None:
                        pl.append(out_p)
                    for tl, t in zip(tensorlists, pl):
                        tl.append(t)
                else:
                    if undo:
                        fused_adam_cuda.adam_undo(p.data,
                                                  exp_avg,
                                                  exp_avg_sq,
                                                  grad,
                                                  group['lr'],
                                                  beta1,
                                                  beta2,
                                                  group['eps'],
                                                  combined_scale,
                                                  step,
                                                  self.eps_mode,
                                                  bias_correction,
                                                  group['weight_decay'])
                    else:
                        fused_adam_cuda.adam(p.data,
                                             out_p,
                                             exp_avg,
                                             exp_avg_sq,
                                             grad,
                                             group['lr'],
                                             beta1,
                                             beta2,
                                             group['eps'],
                                             combined_scale,
                                             step,
                                             self.eps_mode,
                                             bias_correction,
                                             group['weight_decay'])

            if self._use_multi_tensor:
                if undo:
                    multi_tensor_applier(
                        fused_adam_cuda.adam_undo_mt,
                        self._overflow_buf,
                        tensorlists,
                        group['lr'],
                        beta1,
                        beta2,
                        group['eps'],
                        combined_scale,
                        step,
                        self.eps_mode,
                        bias_correction,
                        group['weight_decay'])
                else:
                    multi_tensor_applier(
                        fused_adam_cuda.adam_mt,
                        self._overflow_buf,
                        tensorlists,
                        group['lr'],
                        beta1,
                        beta2,
                        group['eps'],
                        combined_scale,
                        step,
                        self.eps_mode,
                        bias_correction,
                        group['weight_decay'])

        return loss
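For the undo path to work, adam_undo / adam_undo_mt must be exact algebraic inverses of adam / adam_mt when called with the same gradient, hyperparameters and the pre-decrement step value: first reverse the parameter update (which used the updated moments), then roll back exp_avg and exp_avg_sq. A reference sketch of that round-trip in plain PyTorch (eps outside the sqrt, bias correction on, and no weight decay are assumptions; the real kernel's eps_mode and weight_decay handling live in fused_adam_cuda and are not reproduced here):

import torch

def adam_ref(p, exp_avg, exp_avg_sq, grad, lr, beta1, beta2, eps, step):
    # Forward update: new moments, then the bias-corrected parameter step.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias2).sqrt().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias1)

def adam_undo_ref(p, exp_avg, exp_avg_sq, grad, lr, beta1, beta2, eps, step):
    # Invert the parameter update first (it used the *updated* moments) ...
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias2).sqrt().add_(eps)
    p.addcdiv_(exp_avg, denom, value=lr / bias1)
    # ... then roll the moment estimates back.
    exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
    exp_avg_sq.addcmul_(grad, grad, value=-(1 - beta2)).div_(beta2)

torch.manual_seed(0)
p = torch.randn(4); exp_avg = torch.zeros(4); exp_avg_sq = torch.zeros(4)
p0, m0, v0 = p.clone(), exp_avg.clone(), exp_avg_sq.clone()
grad = torch.randn(4)

adam_ref(p, exp_avg, exp_avg_sq, grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, step=1)
adam_undo_ref(p, exp_avg, exp_avg_sq, grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, step=1)

print(torch.allclose(p, p0, atol=1e-6),
      torch.allclose(exp_avg, m0, atol=1e-6),
      torch.allclose(exp_avg_sq, v0, atol=1e-6))   # True True True (up to rounding)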