Commit eca3a2c4 authored by Thor Johnsen

Change inplace to no_copy

parent 9f6c0da5
@@ -154,11 +154,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._works = []
         import inspect
-        if 'inplace' in inspect.getfullargspec(torch.distributed.reduce_scatter).args:
-            self._pg_supports_inplace = True
+        if 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args:
+            self._pg_supports_no_copy = True
         else:
-            self._pg_supports_inplace = False
-            print("WARNING! torch.distributed.reduce_scatter does not support inplace op.")
+            self._pg_supports_no_copy = False
+            print("WARNING! torch.distributed.reduce_scatter does not support no_copy op.")
 
     def set_last_step(self, last_step):
@@ -188,8 +188,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         end = start + self._block_size
         grad_block = flat_grads[start:end]
         grad_shards = [grad_block[i*self._shard_size:(i+1)*self._shard_size] for i in range(self._group_size)]
-        if self._pg_supports_inplace:
-            work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,inplace=True)
+        if self._pg_supports_no_copy:
+            work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True,no_copy=True)
         else:
             work = torch.distributed.reduce_scatter(grad_shards[self._rank_in_group],grad_shards,group=self._rs_pg[block_id%len(self._rs_pg)],async_op=True)
         if self._num_groups > 1:
@@ -210,8 +210,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         shard_end = shard_start + self._shard_size
         block_id = start // self._block_size
         self._partial_step_single_shard(block_id)
-        if self._pg_supports_inplace:
-            work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True,inplace=True)
+        if self._pg_supports_no_copy:
+            work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True,no_copy=True)
         else:
             work = torch.distributed.all_gather(new_params_shards,new_params_shards[self._rank_in_group],group=self._ag_pg[block_id%len(self._ag_pg)],async_op=True)
         return work
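
For context, the change keys off a runtime capability probe rather than a version check: the optimizer inspects the Python signature of torch.distributed.reduce_scatter and only passes no_copy=True when the backend exposes that keyword. The sketch below isolates that pattern. It is illustrative only, the helper name is invented, and no_copy is an extension found in patched PyTorch builds, which is exactly why the argspec test is needed to guard against stock builds that lack it. Calling the helper requires an initialized process group.

    import inspect

    import torch.distributed as dist

    def reduce_scatter_maybe_no_copy(output_shard, input_shards, group=None):
        # Probe the signature: stock torch.distributed.reduce_scatter has no
        # 'no_copy' parameter, so getfullargspec reveals whether the extension
        # is available in this build.
        supports_no_copy = 'no_copy' in inspect.getfullargspec(dist.reduce_scatter).args
        if supports_no_copy:
            # Extended backend: reduce directly out of the caller's shard list,
            # skipping the intermediate flatten/copy.
            return dist.reduce_scatter(output_shard, input_shards, group=group,
                                       async_op=True, no_copy=True)
        # Stock signature: same async call without the extra keyword.
        return dist.reduce_scatter(output_shard, input_shards, group=group,
                                   async_op=True)

The all_gather path in the last hunk follows the same shape; in the commit itself the probe result is cached on the optimizer as self._pg_supports_no_copy so the signature is only inspected once at construction time.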