Commit 173c577b authored by Michael Auli, committed by Myle Ott
Browse files

Momentum correction

parent dd31fa92
......@@ -11,7 +11,7 @@ from torch.optim.optimizer import Optimizer, required
class NAG(Optimizer):
def __init__(self, params, lr=required, momentum=0, weight_decay=0):
defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
super(NAG, self).__init__(params, defaults)
def step(self, closure=None):
......@@ -29,6 +29,8 @@ class NAG(Optimizer):
weight_decay = group['weight_decay']
momentum = group['momentum']
lr = group['lr']
lr_old = group.get('lr_old', lr)
lr_correct = lr / lr_old
for p in group['params']:
if p.grad is None:
......@@ -43,9 +45,11 @@ class NAG(Optimizer):
if weight_decay != 0:
p.data.mul_(1 - weight_decay)
p.data.add_(momentum * momentum, buf)
p.data.add_(momentum * momentum * lr_correct, buf)
p.data.add_(-(1 + momentum) * lr, d_p)
buf.mul_(momentum).add_(-lr, d_p)
buf.mul_(momentum * lr_correct).add_(-lr, d_p)
group['lr_old'] = lr
return loss
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment