Commit e5c8e3c9 authored by Thor Johnsen

Add separate dwu_num_chunks argument

parent f2c9aa33
@@ -45,7 +45,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
-                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False):
+                 dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
+                 dwu_num_chunks=4):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
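
For illustration, a constructor call exercising the new argument could look like the sketch below. This is a hypothetical usage example, not part of the commit: the import path assumes apex's contrib layout, the hyperparameter values are arbitrary, and torch.distributed is assumed to be initialized before the optimizer is built.

# Hypothetical usage sketch for the new dwu_num_chunks argument.
# Assumes torch.distributed.init_process_group("nccl", ...) has already run
# and that apex's fused CUDA extensions are installed.
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

model = torch.nn.Linear(4096, 4096).cuda()
optimizer = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,
    overlap_reductions=True,
    dwu_num_ar_pg=4,    # number of all-reduce process groups
    dwu_num_chunks=8,   # new: chunks per shard, now independent of dwu_num_ar_pg
)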
@@ -76,6 +77,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._overlap_reductions = overlap_reductions
         self._global_scale = None
         self._num_blocks = dwu_num_blocks
+        self._num_chunks = dwu_num_chunks
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
@@ -202,11 +204,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             works = [work]
             if self._num_groups > 1:
-                sliver_size = self._shard_size // len(self._ar_pg)
+                sliver_size = self._shard_size // self._num_chunks
+                assert ((sliver_size*self._num_chunks) == self._shard_size), "Shard size not a multiple of dwu_num_chunks"
                 works = []
-                for i, ar_pg in enumerate(self._ar_pg):
-                    work.wait()
-                    works.append( torch.distributed.all_reduce(grad_shards[self._rank_in_group][i*sliver_size:(i+1)*sliver_size],group=ar_pg,async_op=True) )
+                work.wait()
+                for i in range(self._num_chunks):
+                    works.append( torch.distributed.all_reduce(grad_shards[self._rank_in_group][i*sliver_size:(i+1)*sliver_size],group=self._ar_pg[i%len(self._ar_pg)],async_op=True) )
             if self._compute_L2_grad_norm:
                 with torch.cuda.stream(self._blk_st[0]):
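
To make the change in the last hunk easier to follow in isolation, here is a minimal standalone sketch of the chunked all-reduce pattern; chunked_all_reduce, shard, and ar_pgs are illustrative names invented for this example, and the process groups are assumed to have been created elsewhere. Each shard is split into num_chunks equal slivers, and every sliver is reduced asynchronously on a process group chosen round-robin, which is what decouples the chunk count from the number of all-reduce groups.

import torch
import torch.distributed as dist

def chunked_all_reduce(shard, num_chunks, ar_pgs):
    # shard: 1-D CUDA gradient tensor owned by this rank
    # ar_pgs: list of all-reduce process groups created during optimizer setup
    sliver_size = shard.numel() // num_chunks
    assert sliver_size * num_chunks == shard.numel(), \
        "Shard size not a multiple of num_chunks"
    works = []
    for i in range(num_chunks):
        sliver = shard[i * sliver_size:(i + 1) * sliver_size]
        # round-robin assignment of chunks to process groups
        works.append(dist.all_reduce(sliver, group=ar_pgs[i % len(ar_pgs)], async_op=True))
    return works  # callers wait on these handles before reading the shard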