Commit 9773218c authored by Thor Johnsen

Bug fix: add the missing `grad_averaging` argument to `DistributedFusedLAMB.__init__`, store it in the parameter-group defaults, and read it from the constructor scope instead of the group dict.

parent 02fd7341
@@ -64,7 +64,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
     """
     def __init__(self, params,
-                 lr=1e-3, bias_correction = True,
+                 lr=1e-3, bias_correction = True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
                  weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
                  amp_scale_adjustment=1.0, overlap_reductions=True,
@@ -83,6 +83,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
+                        grad_averaging=grad_averaging,
                         max_grad_norm=max_grad_norm)
         super(DistributedFusedLAMB, self).__init__(params, defaults)
         self.eps_mode = 0 if eps_inside_sqrt else 1
@@ -128,7 +129,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 1 if group['bias_correction'] else 0,
                 beta1,
                 beta2,
-                1.0 - beta1 if group['grad_averaging'] else 1.0,
+                1.0 - beta1 if grad_averaging else 1.0,
                 group['eps']
             ))
             state = self.state[p]
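The expression on the changed line is the coefficient applied to the incoming gradient when accumulating the first moment. Below is a minimal sketch of that semantics, assuming the value acts as the usual grad-averaging coefficient (sometimes called beta3) in an Adam/LAMB-style moment update; first_moment_update is a hypothetical helper for illustration, not part of apex:

import torch

def first_moment_update(m, grad, beta1, grad_averaging):
    # Mirrors the changed line in the diff:
    #   beta3 = 1.0 - beta1 if grad_averaging else 1.0
    # With grad averaging on, the first moment is an exponential moving
    # average of the gradient; with it off, each step adds the full gradient.
    beta3 = 1.0 - beta1 if grad_averaging else 1.0
    return m.mul_(beta1).add_(grad, alpha=beta3)

m = torch.zeros(4)
g = torch.ones(4)
print(first_moment_update(m, g, beta1=0.9, grad_averaging=True))               # all 0.1
print(first_moment_update(torch.zeros(4), g, beta1=0.9, grad_averaging=False)) # all 1.0

After this commit the flag can be passed at construction, e.g. DistributedFusedLAMB(model.parameters(), lr=1e-3, grad_averaging=False). Before it, the kernel-argument code read group['grad_averaging'], a key that was never placed in the defaults dict, so it would raise a KeyError unless the caller supplied the key in every parameter group.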