Commit 2622d7f1 authored by Thor Johnsen

Use glob_chunk to index streams and process groups

parent 85497632
@@ -222,17 +222,18 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         works = [None]*self._num_chunks
         for chunk in range(self._num_chunks):
+            glob_chunk = block_id * self._num_chunks + chunk
             grad_chunk = grad_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
             grad_shards = [grad_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
-            rs_stream = self._rs_st[chunk%self._num_rs_pg]
+            rs_stream = self._rs_st[glob_chunk%self._num_rs_pg]
             rs_stream.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(rs_stream):
-                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[chunk%self._num_rs_pg],async_op=True,no_copy=True)
+                work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[glob_chunk%self._num_rs_pg],async_op=True,no_copy=True)
             if self._num_groups > 1:
-                ar_stream = self._ar_st[chunk%self._num_ar_pg]
+                ar_stream = self._ar_st[glob_chunk%self._num_ar_pg]
                 with torch.cuda.stream(ar_stream):
                     work.wait()
-                    work = torch.distributed.all_reduce(grad_shards[self._rank_in_group],group=self._ar_pg[chunk%self._num_ar_pg],async_op=True)
+                    work = torch.distributed.all_reduce(grad_shards[self._rank_in_group],group=self._ar_pg[glob_chunk%self._num_ar_pg],async_op=True)
             works[chunk] = work
         if self._compute_L2_grad_norm:
@@ -262,13 +263,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         works = [None]*self._num_chunks
         for chunk in range(self._num_chunks):
+            glob_chunk = block_id * self._num_chunks + chunk
             new_params_chunk = new_params_block[chunk*self._chunk_size:(chunk+1)*self._chunk_size]
             new_params_shards = [new_params_chunk[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
-            ag_stream = self._ag_st[chunk%self._num_ag_pg]
+            ag_stream = self._ag_st[glob_chunk%self._num_ag_pg]
             with torch.cuda.stream(ag_stream):
                 self._reductions_works[block_id][chunk].wait()
                 self._partial_step_single_shard(block_id,chunk)
-                work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[chunk%self._num_ag_pg],async_op=True,no_copy=True)
+                work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[glob_chunk%self._num_ag_pg],async_op=True,no_copy=True)
             works[chunk] = work
         self._allgather_works[block_id] = works
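For context on the change above: indexing the stream and process-group pools by glob_chunk (the global chunk id) instead of chunk (the per-block chunk id) means the round-robin assignment continues across blocks rather than restarting at slot 0 for every block. A minimal sketch of the indexing arithmetic is below; the pool sizes are illustrative, not taken from the commit.

# Minimal sketch of the indexing change, with illustrative sizes
# (the real values come from the optimizer's constructor arguments).
num_blocks = 4   # parameter blocks processed per step
num_chunks = 4   # chunks per block
num_rs_pg = 3    # size of the reduce-scatter process-group / stream pool

for block_id in range(num_blocks):
    for chunk in range(num_chunks):
        glob_chunk = block_id * num_chunks + chunk
        old_slot = chunk % num_rs_pg        # same pattern repeats for every block
        new_slot = glob_chunk % num_rs_pg   # pattern rotates from block to block
        print(f"block {block_id} chunk {chunk}: old slot {old_slot}, new slot {new_slot}")

When num_chunks is not a multiple of the pool size, the per-block index reuses the first pool slots more heavily than the rest on every block, while the global index cycles through all slots evenly over the whole step.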