Commit 3ccdfaa3 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Bug fix

parent 3bae8c83
...@@ -264,9 +264,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         grp = torch.distributed.new_group(ranks=ranks)
         if torch.distributed.get_rank() in ranks:
             self._rs_pg.append(grp)
-        if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
-            self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
-            torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
+        if self._compute_L2_grad_norm:
+            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+            if torch.distributed.get_rank() in ranks:
+                self._l2_grad_norm_pg = l2_grad_norm_pg
+                torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
         self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
         for rs_pg in self._rs_pg:
             torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment