Commit 9c82241d authored by Thor Johnsen

Remove implicit memcpy of grad tensor in do_overlapped function

parent 5d1993cf
@@ -99,6 +99,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._param_state = None
         self._model_params = []
         self._grads_info = []
+        self._grad_accs = []
         for group in self.param_groups:
             self._param_group = group
             prev = None
@@ -114,9 +115,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     self._param_state = state
                 p_grads_size = p.numel()
                 def wrapper(param, param_i, param_grads_size, param_offset):
-                    def allreduce_hook(grad):
-                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, grad)
-                    param.register_hook(allreduce_hook)
+                    param_tmp = param.expand_as(param)
+                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
+                    def allreduce_hook(*unused):
+                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
+                    grad_acc.register_hook(allreduce_hook)
+                    self._grad_accs.append(grad_acc)
                 self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
                 wrapper(p, p_i, p_grads_size, p_offset)
                 p_offset += p_grads_size
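For context, the hook pattern introduced in this hunk can be shown in isolation. The sketch below is illustrative and not part of the commit; `w` and the hook body are stand-ins. `w.expand_as(w)` records a trivial autograd view whose `grad_fn.next_functions[0][0]` is the parameter's AccumulateGrad node, and a hook registered on that node fires only after backward has accumulated into `w.grad`, so the reduction can read `w.grad` in place instead of the gradient tensor handed to a `Tensor.register_hook` callback.

import torch

# Hedged standalone sketch of the AccumulateGrad-hook pattern; `w` is illustrative.
w = torch.nn.Parameter(torch.randn(4, 4))

# A no-op autograd view exposes the parameter's AccumulateGrad node.
w_tmp = w.expand_as(w)
grad_acc = w_tmp.grad_fn.next_functions[0][0]

def allreduce_hook(*unused):
    # Runs after backward has accumulated into w.grad; read it in place.
    print("grad ready:", w.grad.norm())

grad_acc.register_hook(allreduce_hook)

(w * 2).sum().backward()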
@@ -160,8 +164,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
 
         self._individual_flat_grads = []
-        for p_i, grads_info in enumerate(self._grads_info):
-            self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]])
+        for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
+            self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
 
         def _flat_split(p):
             def __blockify(p):
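The `.view_as(p)` added above makes each entry of `_individual_flat_grads` a parameter-shaped window into the single flat buffer, so a gradient can be written into it without reshaping or an intermediate tensor. A minimal sketch of that layout, with made-up shapes:

import torch

# Illustrative parameters; shapes are not from the commit.
params = [torch.nn.Parameter(torch.randn(3, 5)), torch.nn.Parameter(torch.randn(7))]

flat_grads = torch.zeros(sum(p.numel() for p in params))

# Per-parameter views into the flat buffer, shaped like each parameter.
individual_flat_grads, offset = [], 0
for p in params:
    n = p.numel()
    individual_flat_grads.append(flat_grads[offset:offset + n].view_as(p))
    offset += n

# Writing through a view fills the corresponding slice of flat_grads in place.
individual_flat_grads[0].copy_(torch.ones(3, 5))
assert flat_grads[:15].eq(1).all()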
@@ -412,12 +416,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     scale)
             self._grads = []
 
-    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, grad):
+    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
         # handle overlapped reductions
         if self._flat_mt:
-            self._grads.append( (grad.view(-1), self._individual_flat_grads[param_i]) )
+            self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
         else:
-            torch.div(grad.view(-1), self._world_size if self._predivide else 1.0, out=self._flat_grads[param_offset:param_offset+param_grads_size])
+            torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
         self._grads_generated[param_i]=True
         if not self._last_step:
             if self._overlap_reductions:
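Putting the pieces together, the non-`_flat_mt` branch now divides `param.grad` straight into its parameter-shaped view of the flat buffer in a single out-of-place op. A hedged sketch of that write path; `world_size`, `predivide`, and the tensors below are stand-ins for the optimizer's state, not the commit's code:

import torch

world_size, predivide = 8, True   # stand-ins for self._world_size / self._predivide

p = torch.nn.Parameter(torch.randn(4, 4))
p.grad = torch.randn(4, 4)

flat_grads = torch.zeros(p.numel())
flat_view = flat_grads.view_as(p)   # parameter-shaped view into the flat buffer

# One divide-and-store: p.grad is read in place and the result lands directly
# in the flat buffer; no temporary flattened copy of the gradient is created.
torch.div(p.grad, world_size if predivide else 1.0, out=flat_view)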