Commit 6eca2389 authored by Thor Johnsen

Pragmatic change, seems like a WAR (workaround) for an NCCL crash

parent 2c744ee5
@@ -68,8 +68,6 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._overflow_buf = torch.cuda.IntTensor([0])
assert (not flat_mt), "flat_mt option is not safe in this version"
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
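For context on the "revert a step" modes listed in the comments above, mode 2 keeps a second fp32 copy of the master parameters so a step can be rolled back. A minimal sketch of that double-buffer idea, with hypothetical names (not apex's actual implementation):

```python
import torch

class DoubleBufferedParams:
    """Sketch of revert mode 2: keep a backup copy of the fp32 master
    parameters so an optimizer step can be undone (hypothetical names)."""
    def __init__(self, fp32_params: torch.Tensor):
        self.params = fp32_params
        self.backup = fp32_params.clone()

    def checkpoint(self):
        # Snapshot the parameters before applying the optimizer step.
        self.backup.copy_(self.params)

    def revert(self):
        # Roll back to the pre-step snapshot (e.g., on gradient overflow).
        self.params.copy_(self.backup)
```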
@@ -167,7 +165,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
-            self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+            #self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+            self._l2_grad_norm_pg = self._rs_pg[-1]
self._rs_st = [torch.cuda.Stream()]*self._num_rs_pg
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
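The change itself replaces the dedicated process group for the L2 grad-norm collective with the last reduce-scatter group, which spans the same ranks; creating a second NCCL communicator over identical ranks appears to be what triggered the crash. A minimal sketch of the pattern, assuming an initialized default process group; `build_rs_groups` and `ranks_per_group` are illustrative stand-ins for apex's actual group-setup code:

```python
import torch.distributed as dist

def build_rs_groups(ranks_per_group, compute_l2_grad_norm=True):
    rs_pg = []
    l2_grad_norm_pg = None
    for ranks in ranks_per_group:
        # new_group is collective: every rank must call it for every group.
        grp = dist.new_group(ranks=ranks)
        if dist.get_rank() in ranks:
            rs_pg.append(grp)
            if compute_l2_grad_norm:
                # WAR: reuse the reduce-scatter group instead of creating a
                # second NCCL communicator over the same ranks, which the
                # commit message suggests was crashing.
                l2_grad_norm_pg = rs_pg[-1]
    return rs_pg, l2_grad_norm_pg
```

Reusing an existing group is safe here because the grad-norm all-reduce and the reduce-scatter cover the same set of ranks, so no separate communicator is strictly needed.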