Unverified Commit d9a46fde authored by Nan Zheng, committed by GitHub

Fix dist lamb (#1185)

1. Remove the weight broadcast in the constructor.
2. Disable the unnecessary all-reduces when gradient clipping happens after the gradient all-reduce (clip-after-ar); see the sketch below.
parent 4e9fae9b
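A minimal sketch of the second change, with hypothetical names (`local_overflow_check`, and plain `clip_after_ar` / `process_group` arguments standing in for the optimizer's `_clip_after_ar` and `_current_process_group` attributes). The idea, as far as the diff shows, is that with clip-after-ar the gradient norm comes from gradients that were already all-reduced, so every rank derives the same overflow flag locally and the extra flag all-reduces can be skipped:

```python
import torch
import torch.distributed as dist

def local_overflow_check(global_grad_norm: torch.Tensor,
                         clip_after_ar: bool,
                         process_group=None) -> torch.Tensor:
    """Return a 0/1 overflow flag for the given gradient norm.

    Illustrative stand-in only: when clip_after_ar is True, global_grad_norm
    is assumed to have been computed from already all-reduced gradients, so
    every rank sees the same value and no extra communication is needed to
    agree on the flag.
    """
    one = torch.ones_like(global_grad_norm, dtype=torch.int32)
    is_finite = (global_grad_norm + 1 > global_grad_norm).int()  # 0 on inf/NaN
    overflow_buf = one * (is_finite ^ one)                       # 1 on overflow

    if not clip_after_ar and dist.is_available() and dist.is_initialized():
        # The norm came from local gradients only, so the ranks must agree:
        # MIN lets any rank's non-finite norm mark the step as non-finite,
        # MAX lets any rank's overflow set the shared overflow buffer.
        dist.all_reduce(is_finite, op=dist.ReduceOp.MIN, group=process_group)
        dist.all_reduce(overflow_buf, op=dist.ReduceOp.MAX, group=process_group)
    return overflow_buf
```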
@@ -270,7 +270,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             eps = group['eps']
             weight_decay = group['weight_decay']
             for p in group['params']:
-                torch.distributed.broadcast(p, 0)
                 if not p.requires_grad:
                     continue
                 self._model_params.append(p)
@@ -729,12 +728,14 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         # check global_grad_norm and fill overflow_buf
         is_finite = (global_grad_norm + 1 > global_grad_norm).int()
         self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1
-        torch.distributed.all_reduce(is_finite,
-                                     op=torch.distributed.ReduceOp.MIN,
-                                     group=self._current_process_group)
-        torch.distributed.all_reduce(self._overflow_buf,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=self._current_process_group)
+        if not self._clip_after_ar:
+            torch.distributed.all_reduce(is_finite,
+                                         op=torch.distributed.ReduceOp.MIN,
+                                         group=self._current_process_group)
+            torch.distributed.all_reduce(self._overflow_buf,
+                                         op=torch.distributed.ReduceOp.MAX,
+                                         group=self._current_process_group)
         # increment step counter if no overflow
         self._step += is_finite
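For context, the unchanged lines in the second hunk use a branch-free finiteness test: `x + 1 > x` is False exactly when `x` is inf or NaN, and XOR with a ones tensor flips the resulting 0/1 into an overflow flag. A standalone illustration of that trick (not part of the commit):

```python
import torch

# x + 1 > x is True only for finite x: any comparison with NaN is False,
# and inf + 1 == inf, so the test also fails for inf.
for x in (torch.tensor(2.0), torch.tensor(float('inf')), torch.tensor(float('nan'))):
    one = torch.ones((), dtype=torch.int32)
    is_finite = (x + 1 > x).int()           # 1 if finite, else 0
    overflow_buf = one * (is_finite ^ one)  # 0 if finite, 1 on inf/NaN
    print(f"{x.item()}: is_finite={int(is_finite)} overflow={int(overflow_buf)}")
# 2.0: is_finite=1 overflow=0
# inf: is_finite=0 overflow=1
# nan: is_finite=0 overflow=1
```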