Commit 7741808b authored by Thor Johnsen

Bug fix: pass the configured adam_w_mode to the fused LAMB kernel instead of a hardcoded value

parent 12458152
@@ -96,7 +96,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import amp_C
         self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
-        self.adam_w_mode = 1 if adam_w_mode else 0
+        self._adam_w_mode = 1 if adam_w_mode else 0
         self._use_nvlamb = use_nvlamb
         self._is_accumulation_step = False
         self._last_step = False
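
The constructor flag is now stored as a private attribute and, in the second hunk below, actually forwarded to the fused kernel. A minimal usage sketch; the import path and the adam_w_mode argument follow apex's contrib optimizer, but the toy model is hypothetical and, in practice, this optimizer also needs torch.distributed process groups initialized before stepping:

    import torch
    from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

    model = torch.nn.Linear(16, 16).cuda()
    # Before this fix, the flag below was stored but the fused step
    # ignored it; adam_w_mode=False now actually selects the
    # non-decoupled (L2 regularization) path.
    optimizer = DistributedFusedLAMB(model.parameters(), lr=1e-3,
                                     adam_w_mode=False)
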
@@ -465,7 +465,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 self._contrib_bias_correction,
                 self._param_state['step']+1,
                 self._contrib_epsilon,
-                1, # adam mode. FIXME: Correct value
+                self._adam_w_mode,
                 self._contrib_weight_decay,
                 self.L2_grad_norm,
                 max_grad_norm)
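
Before this change the kernel call hardcoded mode 1 (AdamW-style decoupled weight decay), so constructing the optimizer with adam_w_mode=False had no effect on the update rule. A single-tensor sketch of the switch the flag controls, assuming mode 0 means classic Adam with L2 regularization and mode 1 means decoupled weight decay; bias correction and gradient clipping are omitted for brevity:

    import torch

    def lamb_step_sketch(param, grad, exp_avg, exp_avg_sq,
                         lr, beta1, beta2, eps, weight_decay, adam_w_mode):
        # Hypothetical single-tensor version of the mode switch the
        # fused kernel's adam mode argument selects.
        if adam_w_mode == 0:
            # Mode 0: classic Adam, weight decay folded into the
            # gradient before the moment updates (L2 regularization).
            grad = grad + weight_decay * param
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        update = exp_avg / (exp_avg_sq.sqrt() + eps)
        if adam_w_mode == 1:
            # Mode 1: AdamW, decoupled weight decay applied to the
            # update rather than the gradient.
            update = update + weight_decay * param
        # LAMB trust ratio: rescale the update by ||param|| / ||update||.
        w_norm = param.norm().item()
        u_norm = update.norm().item()
        trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
        param.add_(update, alpha=-lr * trust_ratio)
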