Commit 6eca2389 authored by Thor Johnsen

Pragmatic change; seems like a workaround (WAR) for an NCCL crash

parent 2c744ee5
@@ -68,8 +68,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._overflow_buf = torch.cuda.IntTensor([0])
-        assert (not flat_mt), "flat_mt option is not safe in this version"
         # Way to revert a step
         # 3 -> undo kernel + double buffer (debug, print norm of difference)
         # 2 -> double buffer fp32 parameters
@@ -167,7 +165,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 if torch.distributed.get_rank() in ranks:
                     self._rs_pg.append(grp)
             if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
-                self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+                #self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+                self._l2_grad_norm_pg = self._rs_pg[-1]
         self._rs_st = [torch.cuda.Stream()]*self._num_rs_pg
         if self._num_ag_pg == 0:
             self._ag_pg = self._rs_pg
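For context, the change above drops the dedicated torch.distributed.new_group call for the L2 grad-norm computation and reuses the last reduce-scatter group instead, so one fewer NCCL communicator is created per rank group. Below is a minimal sketch of that pattern, not the apex implementation: the gloo backend, the helper name build_groups, the group sizing, and the toy all-reduce are illustrative assumptions.

# Sketch of the pattern in this commit: reuse an existing process group for an
# extra collective instead of creating a dedicated one with new_group().
# Illustrative only; the backend ("gloo"), helper name, and tensor are assumptions.
import torch
import torch.distributed as dist

def build_groups(group_size, compute_l2_grad_norm=True):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    rs_pg = []
    l2_grad_norm_pg = None
    for group_i in range(world_size // group_size):
        ranks = [group_i * group_size + j for j in range(group_size)]
        # Every rank must call new_group() in the same order, even ranks
        # that are not members of the group being created.
        grp = dist.new_group(ranks=ranks)
        if rank in ranks:
            rs_pg.append(grp)
        if compute_l2_grad_norm and rank in ranks:
            # Before: l2_grad_norm_pg = dist.new_group(ranks=ranks)
            # After:  reuse the reduce-scatter group that already exists.
            l2_grad_norm_pg = rs_pg[-1]
    return rs_pg, l2_grad_norm_pg

if __name__ == "__main__":
    # Launch with e.g. `torchrun --nproc_per_node=2 this_script.py`.
    dist.init_process_group(backend="gloo")
    rs_pg, l2_pg = build_groups(group_size=dist.get_world_size())
    # The L2 grad-norm all-reduce now runs on the reused group.
    sq_norm = torch.tensor([float(dist.get_rank() + 1)])
    dist.all_reduce(sq_norm, op=dist.ReduceOp.SUM, group=l2_pg)
    print(f"rank {dist.get_rank()}: reduced value = {sq_norm.item()}")
    dist.destroy_process_group()

Reusing the group means the grad-norm all-reduce shares a communicator with the reduce-scatter traffic instead of adding another one per rank group, which appears to be the pragmatic trade-off the commit message alludes to.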