Commit 9430544a, authored by Yann N. Dauphin, committed by Sergey Edunov
Browse files

Decay the weights directly instead of adding an L2 penalty to the gradient (#157)

See "Decoupled Weight Decay Regularization" (Loshchilov & Hutter): https://arxiv.org/pdf/1711.05101.pdf
parent 94dae690
...@@ -35,15 +35,14 @@ class NAG(Optimizer): ...@@ -35,15 +35,14 @@ class NAG(Optimizer):
continue continue
d_p = p.grad.data d_p = p.grad.data
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
param_state = self.state[p] param_state = self.state[p]
if 'momentum_buffer' not in param_state: if 'momentum_buffer' not in param_state:
param_state['momentum_buffer'] = d_p.clone().zero_() param_state['momentum_buffer'] = d_p.clone().zero_()
buf = param_state['momentum_buffer'] buf = param_state['momentum_buffer']
if weight_decay != 0:
p.data.mul_(1 - weight_decay)
p.data.add_(momentum * momentum, buf) p.data.add_(momentum * momentum, buf)
p.data.add_(-(1 + momentum) * lr, d_p) p.data.add_(-(1 + momentum) * lr, d_p)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment