Commit 1e0aadd5 authored by Thor Johnsen

Bug fix in update norm calculation

parent 0f64f6ad
@@ -441,7 +441,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
         local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
         l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
-        torch.distributed.allreduce(l2_norm, group=self._ag_pg[0])
+        torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
         return l2_norm.masked_select(self._model_param_is_contrib)
     def _pipeline_step(self):
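For context, the one-character fix matters because the PyTorch distributed API exposes torch.distributed.all_reduce (with an underscore); the old spelling torch.distributed.allreduce does not exist and would raise an AttributeError the first time the update-norm path ran. The sketch below reproduces the pattern of the fixed code, not the optimizer's actual internals: it runs on a single process with the gloo backend instead of CUDA, and the mask and norm values are placeholder stand-ins for self._model_param_is_contrib and the multi_tensor_l2norm output.

    # Minimal sketch of the fixed update-norm pattern; mask and values are
    # illustrative stand-ins, not the optimizer's real buffers.
    import os
    import torch
    import torch.distributed as dist

    def main():
        # Single-process group, just to exercise the collective call.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)

        model_params_num = 8
        # This rank "contributes" update fragments for a subset of parameters.
        is_contrib = torch.tensor([1, 0, 1, 1, 0, 0, 1, 0], dtype=torch.bool)
        # Squared L2 norms of the local update fragments (placeholder values).
        local_sq_norms = torch.rand(int(is_contrib.sum()))

        # Scatter local squared norms into a full-size buffer, zero elsewhere.
        l2_norm = torch.zeros(model_params_num, dtype=torch.float32)
        l2_norm.masked_scatter_(is_contrib, local_sq_norms)

        # The fix: all_reduce (default op is SUM) exists;
        # dist.allreduce(...) would raise AttributeError.
        dist.all_reduce(l2_norm)

        # Each rank recovers the globally reduced norms for its parameters.
        print(l2_norm.masked_select(is_contrib))

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()

Because the misspelled call sat on the update-norm path rather than the common step path, the bug would only surface when that norm calculation was actually reached, which is consistent with it being caught in a follow-up fix.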