"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "799f5b4e12c5350872b6fe5ebc28be423d2570c3"
Commit f2c9aa33 authored by Thor Johnsen

Add support for 4 all-reduce IB communicators

parent 5c1cf020
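
Note (not part of the patch): the commit spreads each block's inter-group all-reduce over several communicators instead of funnelling it through one. A minimal sketch of the setup side, under the assumption that the communicators are ordinary NCCL process groups built over the same ranks, might look as follows; the helper name make_ar_groups and the default of 4 are illustrative only.

# Illustrative sketch, not code from this commit: several process groups over
# the same ranks yield several independent communicators whose collectives can
# proceed in parallel (the commit title suggests four, one per InfiniBand rail).
import torch.distributed as dist

def make_ar_groups(ranks, num_communicators=4):
    # new_group() must be called collectively, in the same order, on every rank.
    return [dist.new_group(ranks=ranks) for _ in range(num_communicators)]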
@@ -199,12 +199,19 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
         else:
             work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
+        works = [work]
         if self._num_groups > 1:
-            work.wait()
-            work = torch.distributed.all_reduce(grad_shards[self._rank_in_group],group=self._ar_pg[block_id%len(self._ar_pg)],async_op=True)
+            sliver_size = self._shard_size // len(self._ar_pg)
+            works = []
+            for i, ar_pg in enumerate(self._ar_pg):
+                work.wait()
+                works.append( torch.distributed.all_reduce(grad_shards[self._rank_in_group][i*sliver_size:(i+1)*sliver_size],group=ar_pg,async_op=True) )
         if self._compute_L2_grad_norm:
             with torch.cuda.stream(self._blk_st[0]):
-                work.wait()
+                for work in works:
+                    work.wait()
                 if block_id+1 == self._num_blocks:
                     self._L2_grad_norm = grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
                 elif block_id != 0:
@@ -213,7 +220,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     self._L2_grad_norm += grad_shards[self._rank_in_group].norm(dtype=torch.float32,p=2)**2
                     torch.distributed.all_reduce(self._L2_grad_norm,group=self._rs_pg[0])
                     self._L2_grad_norm.sqrt_()
-        return work
+        return works
 
     # NB!
     # self._global_scale is used by this method.
@@ -229,9 +237,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         return work
 
     def _pipeline_block(self, block_id, flat_grads, new_params):
-        work = self._pipeline_block_reductions(block_id, flat_grads)
-        if work is not None:
-            work.wait()
+        works = self._pipeline_block_reductions(block_id, flat_grads)
+        for work in works:
+            if work is not None:
+                work.wait()
         return self._pipeline_block_step(block_id, flat_grads, new_params)
 
     def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
@@ -251,8 +260,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     work = self._pipeline_block(block_id, self._flat_grads, self._new_params)
                     self._works.append(work)
                 else:
-                    work = self._pipeline_block_reductions(block_id, self._flat_grads)
-                    self._works.append(work)
+                    works = self._pipeline_block_reductions(block_id, self._flat_grads)
+                    self._works += works
                 flush_block = self._get_flush_block()
@@ -463,8 +472,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 block_id = self._num_blocks - inv_block_id - 1
                 self._blk_st[block_id%len(self._blk_st)].wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
-                    works = self._pipeline_block_reductions(block_id, self._flat_grads)
-                    self._works.append(work)
+                    works = self._pipeline_block_reductions(block_id, self._flat_grads)
+                    self._works += works
         self._copy_to_fp32 = False
         self._decomp_stats = None
...
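
Taken together, the hunks make _pipeline_block_reductions cut the rank's gradient shard into equal slivers, launch one asynchronous all-reduce per communicator, and return the list of work handles for callers to drain. A rough sketch of that pattern follows; shard and ar_groups are assumed names for the example, not identifiers from the patch.

# Rough sketch of the per-sliver all-reduce pattern used above; `shard` and
# `ar_groups` are assumptions for the example, not names from the patch.
import torch.distributed as dist

def all_reduce_in_slivers(shard, ar_groups):
    # shard: this rank's flat 1-D gradient shard
    # ar_groups: list of process groups, e.g. one per IB communicator
    sliver_size = shard.numel() // len(ar_groups)
    works = []
    for i, pg in enumerate(ar_groups):
        sliver = shard[i*sliver_size:(i+1)*sliver_size]  # contiguous view, reduced in place
        works.append(dist.all_reduce(sliver, group=pg, async_op=True))
    return works  # callers run `for work in works: work.wait()` before reading the shard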