Commit 724672d7 authored by Simon Layton
Browse files

Fix momentum initialization with weight decay

parent b265b0b5
...@@ -108,7 +108,6 @@ class SGD(Optimizer): ...@@ -108,7 +108,6 @@ class SGD(Optimizer):
if 'momentum_buffer' not in param_state: if 'momentum_buffer' not in param_state:
first_run = True first_run = True
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
buf.mul_(momentum).add_(p.grad.data)
momentums.append(buf) momentums.append(buf)
else: else:
first_run = False first_run = False
......
...@@ -102,6 +102,9 @@ struct SGDFunctor ...@@ -102,6 +102,9 @@ struct SGDFunctor
if (momentum != 0.f) { if (momentum != 0.f) {
if (!first_run) { if (!first_run) {
incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii]; incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
} else {
// initialize momentum to current incoming grads
incoming_moms[ii] = incoming_grads[ii];
} }
if (nesterov) { if (nesterov) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment