Commit 12458152 authored by Thor Johnsen

Add more default values from regular lamb optimizer

parent 1e0aadd5
@@ -66,7 +66,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
     def __init__(self, params,
                  lr=1e-3, bias_correction = True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
-                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
+                 weight_decay=0., max_grad_norm=0., amsgrad=False,
+                 adam_w_mode=True, use_nvlamb=False, use_mt=False,
                  amp_scale_adjustment=1.0, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
@@ -95,6 +96,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import amp_C
         self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
+        self.adam_w_mode = 1 if adam_w_mode else 0
+        self._use_nvlamb = use_nvlamb
         self._is_accumulation_step = False
         self._last_step = False
         self._overlap_reductions = overlap_reductions
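
For context, a minimal usage sketch of the two newly exposed keyword arguments (not part of this commit; the import path, the weight_decay value, and the assumption that torch.distributed is already initialized are illustrative only and may differ in this revision of apex):

# Hypothetical usage sketch of the options added by this commit.
# Assumes a CUDA-capable setup with torch.distributed already initialized,
# and that DistributedFusedLAMB is importable from this path in this apex revision.
import torch
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

model = torch.nn.Linear(1024, 1024).cuda()

optimizer = DistributedFusedLAMB(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
    adam_w_mode=True,   # new keyword, default mirrored from the regular LAMB optimizer
    use_nvlamb=False,   # new keyword, default mirrored from the regular LAMB optimizer
)

Both defaults match the regular (non-distributed) LAMB optimizer, per the commit message, so existing call sites keep their previous behavior.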