Commit 9c82241d authored by Thor Johnsen

Remove implicit memcpy of grad tensor in do_overlapped function

parent 5d1993cf
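This commit moves the hook that triggers DistributedFusedAdam's overlapped gradient reductions from the parameter tensor (`Tensor.register_hook`, whose callback is handed an incoming gradient tensor) onto the parameter's `AccumulateGrad` node, so `_do_overlapped_reduction` can operate on `param.grad` directly; per the commit title, this removes an implicit memcpy of the grad tensor. Below is a minimal, self-contained sketch (not from the repository; `param`, `hook`, and `grad_accs` are illustrative names) of the `AccumulateGrad`-hook pattern the diff adopts.

```python
import torch

param = torch.randn(4, requires_grad=True)

# Leaf tensors have no grad_fn, so build a throwaway expand_as view whose grad_fn
# points back at the leaf's AccumulateGrad node.
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]   # AccumulateGrad node for `param`

def hook(*unused):
    # By the time this fires, the gradient has been accumulated into param.grad,
    # so the callback can hand param.grad to the overlapped reduction directly.
    print("accumulated grad:", param.grad)

grad_acc.register_hook(hook)

# Keep a reference to the node (the diff stores it in self._grad_accs) so it and
# its hook are not garbage collected.
grad_accs = [grad_acc]

(param * 2).sum().backward()   # triggers AccumulateGrad for `param`, firing the hook
```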
@@ -99,6 +99,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._param_state = None
         self._model_params = []
         self._grads_info = []
+        self._grad_accs = []
         for group in self.param_groups:
             self._param_group = group
             prev = None
@@ -114,9 +115,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     self._param_state = state
                 p_grads_size = p.numel()
                 def wrapper(param, param_i, param_grads_size, param_offset):
-                    def allreduce_hook(grad):
-                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, grad)
-                    param.register_hook(allreduce_hook)
+                    param_tmp = param.expand_as(param)
+                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
+                    def allreduce_hook(*unused):
+                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
+                    grad_acc.register_hook(allreduce_hook)
+                    self._grad_accs.append(grad_acc)
                 self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
                 wrapper(p, p_i, p_grads_size, p_offset)
                 p_offset += p_grads_size
@@ -160,8 +164,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
         self._individual_flat_grads = []
-        for p_i, grads_info in enumerate(self._grads_info):
-            self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]])
+        for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
+            self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
         def _flat_split(p):
             def __blockify(p):
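The hunk above makes `_individual_flat_grads` hold parameter-shaped views (`.view_as(p)`) into the flat gradient buffer rather than 1-D slices, so a gradient with the parameter's shape can be written into the buffer without reshaping or copying. A small illustrative sketch of that idea, with made-up names (`p`, `flat`):

```python
import torch

p = torch.randn(2, 3)            # stand-in for a model parameter
flat = torch.zeros(p.numel())    # stand-in for the flat gradient buffer

# Slice out this parameter's region and reshape it to the parameter's shape.
# Both operations return views, so flat_view shares storage with `flat`.
flat_view = flat[0:p.numel()].view_as(p)

flat_view.copy_(p)                      # writes land directly in `flat`, no extra buffer
print(torch.equal(flat.view_as(p), p))  # True
```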
@@ -412,12 +416,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     scale)
             self._grads = []
 
-    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
+    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
         # handle overlapped reductions
         if self._flat_mt:
-            self._grads.append( (grad.view(-1), self._individual_flat_grads[param_i]) )
+            self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
         else:
-            torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
+            torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
         self._grads_generated[param_i]=True
         if not self._last_step:
             if self._overlap_reductions:
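In the rewritten `_do_overlapped_reduction` above, the multi-tensor path records `(param.grad, flat-buffer view)` pairs in `self._grads` for later processing, and the fallback path pre-divides `param.grad` straight into the parameter-shaped flat-buffer view via `out=`, so no intermediate flattened gradient is materialized. A hedged sketch of that `out=` pattern, with made-up sizes and names:

```python
import torch

world_size = 8
param = torch.randn(2, 3, requires_grad=True)
(param * 2).sum().backward()            # populate param.grad

flat_grads = torch.zeros(param.numel()) # stand-in for the shared flat gradient buffer
flat_view = flat_grads[:param.numel()].view_as(param)

# Pre-divide the accumulated gradient and write the result straight into the
# flat-buffer view; out= avoids materializing a temporary result tensor.
torch.div(param.grad, world_size, out=flat_view)

print(torch.allclose(flat_view, param.grad / world_size))  # True
```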