Commit 724672d7 authored by Simon Layton's avatar Simon Layton
Browse files

Fix momentum initialization with weight decay

parent b265b0b5
......@@ -108,7 +108,6 @@ class SGD(Optimizer):
if 'momentum_buffer' not in param_state:
first_run = True
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
buf.mul_(momentum).add_(p.grad.data)
momentums.append(buf)
else:
first_run = False
......
......@@ -102,6 +102,9 @@ struct SGDFunctor
if (momentum != 0.f) {
if (!first_run) {
incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
} else {
// initialize momentume to current incoming grads
incoming_moms[ii] = incoming_grads[ii];
}
if (nesterov) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment