Commit f2c9aa33 authored by Thor Johnsen

Add support for 4 all-reduce IB communicators

parent 5c1cf020
@@ -199,11 +199,18 @@ class DistributedFusedAdam(torch.optim.Optimizer):
work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
else:
work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
works = [work]
if self._num_groups > 1:
sliver_size = self._shard_size // len(self._ar_pg)
works = []
for i, ar_pg in enumerate(self._ar_pg):
work.wait()
work = torch.distributed.all_reduce(grad_shards[self._rank_in_group],group=self._ar_pg[block_id%len(self._ar_pg)],async_op=True)
works.append( torch.distributed.all_reduce(grad_shards[self._rank_in_group][i*sliver_size:(i+1)*sliver_size],group=ar_pg,async_op=True) )
if self._compute_L2_grad_norm:
with torch.cuda.stream(self._blk_st[0]):
for work in works:
work.wait()
if block_id+1 == self._num_blocks:
self._L2_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
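The hunk above is the heart of the change: after the reduce-scatter, the local gradient shard is no longer all-reduced in one piece on a single communicator; it is cut into one contiguous sliver per all-reduce process group, and each sliver is all-reduced asynchronously on its own communicator. A minimal standalone sketch of that pattern, assuming torch.distributed is already initialized (sliver_all_reduce, shard and ar_pgs are illustrative names, not part of the optimizer):

import torch.distributed as dist

def sliver_all_reduce(shard, ar_pgs):
    # Split the local gradient shard into one contiguous sliver per
    # all-reduce process group and launch an async all_reduce for each
    # sliver on its own communicator. Assumes shard.numel() is divisible
    # by len(ar_pgs). Returns the list of async work handles.
    sliver_size = shard.numel() // len(ar_pgs)
    works = []
    for i, ar_pg in enumerate(ar_pgs):
        sliver = shard[i * sliver_size:(i + 1) * sliver_size]
        works.append(dist.all_reduce(sliver, group=ar_pg, async_op=True))
    return works

With the four communicators named in the commit title, each IB communicator would carry roughly a quarter of the inter-node all-reduce traffic for a block.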
@@ -213,7 +220,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
torch.distributed.all_reduce(self._L2_grad_norm,group=self._rs_pg[0])
self._L2_grad_norm.sqrt_()
return work
return works
# NB!
# self._global_scale is used by this method.
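The two hunks above also keep the L2 gradient-norm bookkeeping working with the new list of handles: each rank accumulates the squared norm of its own shard block by block, all-reduces the scalar sum, then takes the square root. A compact sketch of the same idea outside the pipeline (sharded_grad_norm and local_shards are illustrative names; the real code reduces over self._rs_pg[0]):

import torch
import torch.distributed as dist

def sharded_grad_norm(local_shards, group=None):
    # Accumulate squared L2 norms of the local shards, sum the scalar
    # across ranks, then take the square root, mirroring the per-block
    # accumulation plus final all_reduce/sqrt_ in the hunks above.
    sq = torch.zeros(1, dtype=torch.float32, device=local_shards[0].device)
    for shard in local_shards:
        sq += shard.norm(dtype=torch.float32, p=2) ** 2
    dist.all_reduce(sq, group=group)
    return sq.sqrt()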
@@ -229,7 +237,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
return work
def _pipeline_block(self, block_id, flat_grads, new_params):
work = self._pipeline_block_reductions(block_id, flat_grads)
works = self._pipeline_block_reductions(block_id, flat_grads)
for work in works:
if work is not None:
work.wait()
return self._pipeline_block_step(block_id, flat_grads, new_params)
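Because _pipeline_block_reductions now returns a list of handles, one per outstanding collective, rather than a single handle, _pipeline_block waits on all of them before running the step. The calling convention in a nutshell (wait_all is a hypothetical helper, not part of the class):

def wait_all(works):
    # Block until every async collective handle has completed; None
    # entries (synchronous code paths) are skipped, as in the guard above.
    for work in works:
        if work is not None:
            work.wait()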
@@ -251,8 +260,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
work = self._pipeline_block(block_id, self._flat_grads, self._new_params)
self._works.append(work)
else:
work = self._pipeline_block_reductions(block_id, self._flat_grads)
self._works.append(work)
works = self._pipeline_block_reductions(block_id, self._flat_grads)
self._works += works
flush_block = self._get_flush_block()
@@ -463,8 +472,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
block_id = self._num_blocks - inv_block_id - 1
self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
work = self._pipeline_block_reductions(block_id, self._flat_grads)
self._works.append(work)
works = self._pipeline_block_reductions(block_id, self._flat_grads)
self._works += works
self._copy_to_fp32 = False
self._decomp_stats = None
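The launch site above keeps the per-block side-stream pipelining: each block's stream first waits on the current stream, the block's reductions are then enqueued on that stream, and the returned handles are collected in self._works for a later flush. A rough sketch of that loop, with blk_streams and launch_block standing in for self._blk_st and _pipeline_block_reductions:

import torch

def launch_all_reductions(num_blocks, blk_streams, launch_block):
    # Iterate blocks in reverse, as the call site above does, and hand
    # each block to a side stream chosen round-robin. wait_stream makes
    # the side stream wait for work already queued on the current stream
    # before the block's collectives run on it.
    all_works = []
    for inv_block_id in range(num_blocks):
        block_id = num_blocks - inv_block_id - 1
        stream = blk_streams[block_id % len(blk_streams)]
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            all_works += launch_block(block_id)
    return all_works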