Commit cef660ba authored by Thor Johnsen

Bug fix: allocate a distinct CUDA stream per process group instead of aliasing a single stream via list multiplication.

parent 6eca2389
@@ -153,7 +153,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 grp = torch.distributed.new_group(ranks=ranks)
                 if torch.distributed.get_rank() in ranks:
                     self._ar_pg.append(grp)
-        self._ar_st = [torch.cuda.Stream()]*self._num_ar_pg
+        self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
         rs_ranks = []
         for group_i in range(self._num_groups):
             rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
@@ -167,7 +167,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
                 #self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
                 self._l2_grad_norm_pg = self._rs_pg[-1]
-        self._rs_st = [torch.cuda.Stream()]*self._num_rs_pg
+        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
         if self._num_ag_pg == 0:
             self._ag_pg = self._rs_pg
             self._ag_st = self._rs_st
@@ -180,7 +180,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     grp = torch.distributed.new_group(ranks=ranks)
                     if torch.distributed.get_rank() in ranks:
                         self._ag_pg.append(grp)
-            self._ag_st = [torch.cuda.Stream()]*self._num_ag_pg
+            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
         self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
         self._completion_st = torch.cuda.Stream()
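All three hunks fix the same pitfall: list multiplication evaluates torch.cuda.Stream() once and repeats the reference, so the all-reduce, reduce-scatter, and all-gather paths each end up queuing their work on one shared stream rather than on self._num_*_pg independent streams. The list comprehension constructs a separate stream per process group. A minimal sketch of the difference, not part of the commit and assuming a CUDA-capable device:

import torch

# List multiplication repeats the same object reference: all four slots
# alias one CUDA stream, so "per-group" work is serialized on it.
aliased = [torch.cuda.Stream()] * 4
assert all(s is aliased[0] for s in aliased)

# The comprehension used in this commit calls torch.cuda.Stream() once per
# iteration, producing genuinely independent streams that can overlap.
distinct = [torch.cuda.Stream() for _ in range(4)]
assert len({id(s) for s in distinct}) == 4

The same aliasing applies to any mutable object built with list multiplication (for example [[]] * n); a comprehension is the idiomatic way to get n distinct instances.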