Commit c7372320 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Add backwards compatible support for no inplace NCCL op

parent 400cf628
......@@ -153,6 +153,14 @@ class DistributedFusedAdam(torch.optim.Optimizer):
self._blk_st.append(torch.cuda.Stream())
self._works = []
import inspect
if if 'inplace' in inspect.getfullargspec(torch.distributed.reduce_scatter).args:
self._pg_supports_inplace = True
else:
self._pg_supports_inplace = False
print("WARNING! torch.distributed.reduce_scatter does not support inplace op.")
def set_last_step(self, last_step):
self._last_step = last_step
......@@ -180,7 +188,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
end = start + self._block_size
grad_block = flat_grads[start:end]
grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,inplace=True)
if self._pg_supports_inplace:
work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,inplace=True)
else:
work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
if self._num_groups > 1:
work.wait()
work = torch.distributed.all_reduce(grad_shards[self._rank_in_group],group=self._ar_pg[block_id%len(self._ar_pg)],async_op=True)
......@@ -199,7 +210,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
shard_end = shard_start + self._shard_size
block_id = start // self._block_size
self._partial_step_single_shard(block_id)
work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True,inplace=True)
if self._pg_supports_inplace:
work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True,inplace=True)
else:
work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True)
return work
def _pipeline_block(self, block_id, flat_grads, new_params):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment