Unverified commit b37110e3, authored by Farhad Ramezanghorbani, committed by GitHub

AdamW implementation minor fix (#261)

parent 48eeb3dc
@@ -65,6 +65,12 @@ class AdamW(Optimizer):
             for p in group['params']:
                 if p.grad is None:
                     continue
+
+                # Perform stepweight decay
+                # p.data.mul_(1 - group['lr'] * group['weight_decay'])  # AdamW
+                p.data.mul_(1 - group['weight_decay'])  # Neurochem
+
+                # Perform optimization step
                 grad = p.grad.data
                 if grad.is_sparse:
                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
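For reference, the commented-out line above is the decoupled weight decay of the AdamW paper, in which the shrink factor scales with the learning rate; the line kept in this commit applies a fixed multiplicative decay each step (the Neurochem convention). A minimal sketch of the two variants, using illustrative tensor and hyperparameter values rather than anything from this repository:

```python
import torch

p = torch.randn(3)              # illustrative parameter tensor
lr, weight_decay = 1e-3, 1e-4   # illustrative hyperparameter values

# AdamW-style decoupled decay: the shrink factor scales with the learning rate.
# p.mul_(1 - lr * weight_decay)

# Decay used in this commit: a fixed shrink factor per step, independent of lr.
p.mul_(1 - weight_decay)
```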
@@ -89,6 +95,8 @@ class AdamW(Optimizer):
                 beta1, beta2 = group['betas']

                 state['step'] += 1
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']

                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
@@ -97,17 +105,12 @@ class AdamW(Optimizer):
                     # Maintains the maximum of all 2nd moment running avg. till now
                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                     # Use the max. for normalizing running avg. of gradient
-                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 else:
-                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                step_size = group['lr'] / bias_correction1

                 p.data.addcdiv_(-step_size, exp_avg, denom)

-                if group['weight_decay'] != 0:
-                    p.data.add_(-group['weight_decay'], p.data)
-
         return loss
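Taken together, the remaining hunks move the bias correction out of the combined `step_size` factor: `bias_correction2` now rescales the second-moment denominator directly and `bias_correction1` divides the learning rate, while the old additive weight-decay block at the end is dropped in favour of the multiplicative decay above. A rough, self-contained sketch of the resulting single-parameter update (amsgrad branch omitted; all names and values below are illustrative, not taken from the repository):

```python
import math
import torch

# Illustrative state and hyperparameters for a single parameter tensor.
p = torch.randn(3)
grad = torch.randn(3)
exp_avg = torch.zeros(3)
exp_avg_sq = torch.zeros(3)
lr, eps, weight_decay = 1e-3, 1e-8, 1e-4
beta1, beta2 = 0.9, 0.999
step = 1

# Multiplicative weight decay before the optimization step (first hunk).
p.mul_(1 - weight_decay)

# Biased first and second moment estimates.
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

# Bias corrections (second hunk).
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

# Bias-corrected denominator and step size (third hunk).
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
step_size = lr / bias_correction1

# Parameter update: p <- p - step_size * exp_avg / denom
p.addcdiv_(exp_avg, denom, value=-step_size)
```

With `weight_decay = 0` this reduces to a plain bias-corrected Adam step; the old and new formulations of the update differ only in where `eps` enters the denominator, since lr·√bc2/bc1 · m/(√v + √bc2·eps) equals (lr/bc1) · m/(√v/√bc2 + eps).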