Commit aa90d31f authored by Thor Johnsen

Add internal pipelining option

parent be4c41c2
@@ -46,7 +46,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
                  dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
-                 dwu_num_chunks=4, predivide=True):
+                 dwu_num_chunks=4, predivide=True, internal_pipeline=False):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
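
For reference, a minimal usage sketch of the new flag. The import path, model, and launcher setup are assumptions rather than part of this commit, and the optimizer still needs CUDA plus an initialized NCCL process group (e.g. via `torchrun`).

```python
# Hypothetical usage sketch: only internal_pipeline / dwu_num_chunks come from
# this commit; the import path, model, lr, and launch setup are assumptions.
import os
import torch
import torch.distributed as dist
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

def build_optimizer():
    # Assumes launch via torchrun, which sets RANK/WORLD_SIZE/LOCAL_RANK, etc.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    model = torch.nn.Linear(1024, 1024).cuda()
    return DistributedFusedAdam(model.parameters(),
                                lr=1e-3,
                                dwu_num_chunks=4,
                                internal_pipeline=True)  # option added in this commit
```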
@@ -79,6 +79,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._predivide = predivide
+        self._internal_pipeline = internal_pipeline
         self._full_pipeline = full_pipeline
         self._compute_L2_grad_norm = compute_L2_grad_norm
         self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
@@ -209,19 +210,33 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         grad_block = flat_grads[start:end]
         grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
-        if self._pg_supports_no_copy:
-            work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
-        else:
-            work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
-        works = [work]
-        if self._num_groups > 1:
-            work.wait()
-            works = []
-            chunk_size = self._shard_size // self._num_chunks
-            for i in range(self._num_chunks):
-                chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
-                work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
-                works.append(work)
+        if self._internal_pipeline:
+            works = []
+            chunk_size = self._shard_size // self._num_chunks
+            for i in range(self._num_chunks):
+                chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
+                if self._pg_supports_no_copy:
+                    work = torch.distributed.reduce_scatter(chunks[self._rank_in_group],chunks,group=self._rs_pg[i%len(self._rs_pg)],async_op=True,no_copy=True)
+                else:
+                    work = torch.distributed.reduce_scatter(chunks[self._rank_in_group],chunks,group=self._rs_pg[i%len(self._rs_pg)],async_op=True)
+                if self._num_groups > 1:
+                    work.wait()
+                    work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
+                works.append(work)
+        else:
+            if self._pg_supports_no_copy:
+                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
+            else:
+                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
+            works = [work]
+            if self._num_groups > 1:
+                work.wait()
+                works = []
+                chunk_size = self._shard_size // self._num_chunks
+                for i in range(self._num_chunks):
+                    chunks = [grad_shards[j][i*chunk_size:(i+1)*chunk_size] for j in range(self._group_size)]
+                    work = torch.distributed.all_reduce(chunks[self._rank_in_group],group=self._ar_pg[i%len(self._ar_pg)],async_op=True)
+                    works.append(work)
         if self._compute_L2_grad_norm:
             with torch.cuda.stream(self._blk_st[0]):
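
To make the new branch easier to follow, here is a self-contained sketch of the chunk-level pipeline it implements. Function and argument names are illustrative, not the optimizer's internals; the extension-specific `no_copy` argument is omitted, and the sketch assumes an initialized NCCL backend plus pre-built intra-group (reduce-scatter) and inter-group (all-reduce) process groups.

```python
# Illustrative sketch of the chunked "internal pipeline" reduction pattern.
# Assumes torch.distributed is initialized with NCCL and that rs_pgs / ar_pgs
# are lists of pre-created intra-group and inter-group process groups.
import torch
import torch.distributed as dist

def pipelined_block_reduction(grad_block, group_size, rank_in_group,
                              num_chunks, num_groups, rs_pgs, ar_pgs):
    """Reduce a flattened gradient block in num_chunks pieces so the
    inter-group all-reduce of one chunk can overlap with the intra-group
    reduce-scatter of the next."""
    shard_size = grad_block.numel() // group_size
    shards = [grad_block[i * shard_size:(i + 1) * shard_size]
              for i in range(group_size)]
    chunk_size = shard_size // num_chunks
    works = []
    for i in range(num_chunks):
        chunks = [shards[j][i * chunk_size:(i + 1) * chunk_size]
                  for j in range(group_size)]
        # Intra-group reduce-scatter of this chunk, round-robined over the
        # available reduce-scatter process groups.
        work = dist.reduce_scatter(chunks[rank_in_group], chunks,
                                   group=rs_pgs[i % len(rs_pgs)],
                                   async_op=True)
        if num_groups > 1:
            # Chain an inter-group all-reduce on the shard this rank owns.
            work.wait()
            work = dist.all_reduce(chunks[rank_in_group],
                                   group=ar_pgs[i % len(ar_pgs)],
                                   async_op=True)
        works.append(work)
    return works
```

Splitting each shard into chunks lets the inter-group all-reduce of an earlier chunk run while later chunks are still being reduce-scattered, which appears to be the overlap the `internal_pipeline` option is after; the round-robin over process groups keeps the chunk collectives on separate communicators.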