"vscode:/vscode.git/clone" did not exist on "5c9e1e285e50c7be6cbcec04c47b4f0b929ede85"
Commit 9bb71066 authored by Thor Johnsen

Revert regular contrib fused adam optimizer

parent 7e3536dd
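
This commit reverts the experimental overflow-check/undo machinery from the contrib FusedAdam: the `has_overflow` and `peek_overflow` properties, the `strided_check_finite` helper, the `allow_undo` flag on `step(...)`, and the internal `_step(...)` split. For context, the removed API let a training loop attempt an optimizer step and roll everything back if the scaled gradients overflowed. A hypothetical sketch of that pattern (the loop, `model`, `loader`, `loss_fn`, and the loss-scale policy are illustrative, not part of this commit):

```python
# Hypothetical dynamic-loss-scaling loop built on the API this commit removes.
# `opt` is assumed to be the pre-revert contrib FusedAdam.
loss_scale = 2.0 ** 15
for inputs, targets in loader:
    loss = loss_fn(model(inputs), targets)
    (loss * loss_scale).backward()
    # With allow_undo=True, an overflow detected in the updated parameters
    # triggers the undo kernels, restoring the parameters, exp_avg,
    # exp_avg_sq, and the step counter to their pre-step values.
    opt.step(scale=loss_scale, allow_undo=True)
    if opt.has_overflow:           # reading has_overflow also clears the flag
        loss_scale /= 2.0          # step was undone; retry with a smaller scale
    opt.zero_grad()
```

The reverted diff follows.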
```diff
@@ -61,38 +61,7 @@ class FusedAdam(torch.optim.Optimizer):
         super(FusedAdam, self).__init__(params, defaults)
         self.eps_mode = 0 if eps_inside_sqrt else 1
         self._overflow_buf = torch.cuda.IntTensor([0])
 
-    @property
-    def has_overflow(self):
-        """Check if overflows were detected by any call to step(...) method.
-        Clears the overflow flag.
-        """
-        has_overflow = self._overflow_buf.item()
-        self._overflow_buf.zero_()
-        return has_overflow
-
-    @property
-    def peek_overflow(self):
-        """Check if overflows were detected by any call to step(...) method.
-        Does not clear overflow flag.
-        """
-        return self._overflow_buf.item()
-
-    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
-        """Strided check for overflow.
-        You can get status by calling has_overflow.
-        """
-        if start >= 0 and start < end:
-            out_p = output_params[start:end]
-        else:
-            out_p = output_params
-        fused_adam_cuda.strided_check_finite(self._overflow_buf,
-                                             out_p,
-                                             stride,
-                                             1 if clear else 0)
-
-    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None, allow_undo=False):
+    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
         """Performs a single optimization step.
 
         Arguments:
```
```diff
@@ -106,22 +75,11 @@ class FusedAdam(torch.optim.Optimizer):
                 updated weights. Have to be of same type as gradients. (default: None)
             scale (float, optional): factor to divide gradient tensor values
                 by before applying to weights. (default: 1)
-            allow_undo (bool, optional): allow use of undo feature. Internal buffers
-                will be restored to pre-step state if overflow is detected in gradient.
         """
         loss = None
         if closure is not None:
             loss = closure()
-        self._step(grads, output_params, scale, grad_norms, allow_undo, False)
-        if allow_undo and self.peek_overflow:
-            self._step(grads, output_params, scale, grad_norms, False, True)
-        return loss
-
-    def _step(self, grads, output_params, scale, grad_norms, check_overflow, undo):
-        if check_overflow:
-            modified_params = []
         if hasattr(self, "_amp_stash"):
             grads = self._amp_stash.grads
             output_params = self._amp_stash.output_params
```
```diff
@@ -172,6 +130,7 @@ class FusedAdam(torch.optim.Optimizer):
                 tensorlists = [[],[],[],[],[]]
             else:
                 tensorlists = [[],[],[],[]]
+            tensordevice = None
 
             for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                 #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
```
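
The `tensordevice = None` line added above (together with the device check later in this diff) guards the multi-tensor path: `multi_tensor_applier` launches one fused kernel over chunked lists of tensors, so every tensor in `tensorlists` must live on a single device. A rough sketch of that dispatch pattern, using apex's `multi_tensor_applier` with `amp_C.multi_tensor_scale` as the example op (the tensors themselves are illustrative):

```python
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.cuda.IntTensor([0])  # set nonzero if a kernel sees inf/nan
srcs = [torch.randn(1024, device='cuda') for _ in range(16)]
dsts = [torch.empty_like(t) for t in srcs]
# A single kernel launch scales all 16 tensors: dsts[i] = srcs[i] * 0.5.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [srcs, dsts], 0.5)
```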
```diff
@@ -185,53 +144,34 @@ class FusedAdam(torch.optim.Optimizer):
                 state = self.state[p]
                 # State initialization
-                if undo:
-                    assert (len(state) > 0), "Adam undo called with empty optimizer state"
-                else:
-                    if len(state) == 0:
-                        state['step'] = 0
-                        # Exponential moving average of gradient values
-                        state['exp_avg'] = torch.zeros_like(p.data)
-                        # Exponential moving average of squared gradient values
-                        state['exp_avg_sq'] = torch.zeros_like(p.data)
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']
-                if undo:
-                    step = state['step']
-                    state['step'] -= 1
-                else:
-                    state['step'] += 1
-                    step = state['step']
+                state['step'] += 1
 
-                if not undo:
-                    out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
-                if check_overflow:
-                    modified_params.append(out_p)
+                out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
                 if self._use_multi_tensor:
                     pl = [p.data, exp_avg, exp_avg_sq, grad]
-                    if not undo and output_param is not None:
+                    if output_param is not None:
                         pl.append(out_p)
                     for tl, t in zip(tensorlists, pl):
                         tl.append(t)
+                    if tensordevice is None:
+                        tensordevice = p.device
+                    elif tensordevice != p.device:
+                        raise RuntimeError('FusedAdam does not support use_mt with tensors on multiple device')
                 else:
-                    if undo:
-                        fused_adam_cuda.adam_undo(p.data,
-                                                  exp_avg,
-                                                  exp_avg_sq,
-                                                  grad,
-                                                  group['lr'],
-                                                  beta1,
-                                                  beta2,
-                                                  group['eps'],
-                                                  combined_scale,
-                                                  step,
-                                                  self.eps_mode,
-                                                  bias_correction,
-                                                  group['weight_decay'])
-                    else:
+                    with torch.cuda.device(p.device):
                         fused_adam_cuda.adam(p.data,
                                              out_p,
                                              exp_avg,
```
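
For reference, the update that `fused_adam_cuda.adam` applies per parameter is ordinary (optionally bias-corrected) Adam. The sketch below is a plain-PyTorch rendering of that math, not the kernel itself; the `eps` placement follows the constructor's `eps_mode` convention (`0` = inside the square root, `1` = outside), and the exact weight-decay placement is an assumption:

```python
import torch

def adam_reference(p, exp_avg, exp_avg_sq, grad, lr, beta1, beta2, eps,
                   combined_scale, step, eps_mode, bias_correction, weight_decay):
    # Non-fused sketch of one Adam update, mirroring the kernel's argument list.
    g = grad.float() / combined_scale                       # undo loss scaling
    exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)            # first moment
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # second moment
    m, v = exp_avg, exp_avg_sq
    if bias_correction:
        m = m / (1 - beta1 ** step)
        v = v / (1 - beta2 ** step)
    denom = (v + eps).sqrt() if eps_mode == 0 else v.sqrt() + eps
    update = m / denom + weight_decay * p.data              # decay placement assumed
    p.data.add_(update, alpha=-lr)
```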
```diff
@@ -242,27 +182,13 @@ class FusedAdam(torch.optim.Optimizer):
                                              beta2,
                                              group['eps'],
                                              combined_scale,
-                                             step,
+                                             state['step'],
                                              self.eps_mode,
                                              bias_correction,
                                              group['weight_decay'])
             if self._use_multi_tensor:
-                if undo:
-                    multi_tensor_applier(
-                        fused_adam_cuda.adam_undo_mt,
-                        self._overflow_buf,
-                        tensorlists,
-                        group['lr'],
-                        beta1,
-                        beta2,
-                        group['eps'],
-                        combined_scale,
-                        step,
-                        self.eps_mode,
-                        bias_correction,
-                        group['weight_decay'])
-                else:
+                with torch.cuda.device(tensordevice):
                     multi_tensor_applier(
                         fused_adam_cuda.adam_mt,
                         self._overflow_buf,
@@ -272,11 +198,9 @@ class FusedAdam(torch.optim.Optimizer):
                         beta2,
                         group['eps'],
                         combined_scale,
-                        step,
+                        state['step'],
                         self.eps_mode,
                         bias_correction,
                         group['weight_decay'])
-        if check_overflow:
-            for i, out_p in enumerate(modified_params):
-                self.strided_check_finite(out_p, stride=out_p.numel(), start=0, end=out_p.numel(), clear=True if i == 0 else False)
+
+        return loss
```
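
After the revert, the contrib optimizer is back to a plain fused Adam: `step()` always applies the update, and any overflow handling is the caller's responsibility. A minimal usage sketch (the import path assumes the contrib layout of apex and a build with the CUDA extensions):

```python
import torch
from apex.contrib.optimizers.fused_adam import FusedAdam  # import path assumed

model = torch.nn.Linear(1024, 1024).cuda()
opt = FusedAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(32, 1024, device='cuda')).sum()
loss.backward()
opt.step()          # no allow_undo / has_overflow after this revert
opt.zero_grad()
```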