Commit 9773218c authored by Thor Johnsen

Bug fix: plumb grad_averaging through DistributedFusedLAMB.__init__ and the optimizer defaults so the beta1 gradient coefficient can be computed without error.

parent 02fd7341
@@ -64,7 +64,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
     """
     def __init__(self, params,
-                 lr=1e-3, bias_correction = True,
+                 lr=1e-3, bias_correction = True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
                  weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
                  amp_scale_adjustment=1.0, overlap_reductions=True,
@@ -83,6 +83,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
+                        grad_averaging=grad_averaging,
                         max_grad_norm=max_grad_norm)
         super(DistributedFusedLAMB, self).__init__(params, defaults)
         self.eps_mode = 0 if eps_inside_sqrt else 1
@@ -128,7 +129,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     1 if group['bias_correction'] else 0,
                     beta1,
                     beta2,
-                    1.0 - beta1 if group['grad_averaging'] else 1.0,
+                    1.0 - beta1 if grad_averaging else 1.0,
                     group['eps']
                 ))
                 state = self.state[p]
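For context, a minimal sketch of what the plumbed-through flag controls, assuming the usual Adam/LAMB-style first-moment update. The helper below is hypothetical and for illustration only; it is not part of apex or of this commit. The value `1.0 - beta1 if grad_averaging else 1.0` that the last hunk passes along is the coefficient applied to the incoming gradient: with grad_averaging=True the first moment is an exponential moving average of the gradient, with grad_averaging=False the gradient enters at full weight.

import torch

def first_moment_update(exp_avg, grad, beta1, grad_averaging=True):
    # Hypothetical helper, not the fused apex kernel: it only shows what the
    # coefficient plumbed through in this commit means for the moment update.
    beta3 = 1.0 - beta1 if grad_averaging else 1.0  # value passed in the last hunk
    return exp_avg.mul_(beta1).add_(grad, alpha=beta3)

g = torch.ones(4)
print(first_moment_update(torch.zeros(4), g, beta1=0.9, grad_averaging=True))   # 0.1 * g
print(first_moment_update(torch.zeros(4), g, beta1=0.9, grad_averaging=False))  # 1.0 * g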