"tests/distributed/test_distributed_sampling.py" did not exist on "ff5b5a4a1090f79e82163f43909a1831b7fce924"
Commit 5d1993cf authored by Thor Johnsen

Don't pad between consecutive parameters

parent e1a4deba
@@ -101,6 +101,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._grads_info = []
         for group in self.param_groups:
             self._param_group = group
+            prev = None
             for p in group['params']:
                 torch.distributed.broadcast(p,0)
                 if not p.requires_grad:
@@ -119,7 +120,10 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
                     wrapper(p, p_i, p_grads_size, p_offset)
                     p_offset += p_grads_size
-                    # enforce 128b alignment (64 * fp16)
-                    p_offset = ((p_offset + 63) // 64) * 64
+                    # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
+                    # RNN is one example of consecutive parameters:
+                    # (weight_ih, weight_hh, bias_ih, bias_hh)
+                    if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
+                        p_offset = ((p_offset + 63) // 64) * 64
                     p_i += 1
         self._grads_generated = [False]*len(self._grads_info)
...
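The change skips the 128-byte padding when a parameter's storage begins exactly where the previous parameter's storage ends, so fused blocks such as an RNN's (weight_ih, weight_hh, bias_ih, bias_hh) keep their contiguous layout in the flattened gradient buffer. Below is a minimal sketch of that contiguity check and the alignment arithmetic; the helper names (is_consecutive, layout_offsets) are illustrative and not part of apex, and the sketch adds a prev = p update that the check needs but that falls outside the hunk shown above.

import torch

def is_consecutive(prev: torch.Tensor, p: torch.Tensor) -> bool:
    # True when p's storage starts exactly where prev's storage ends.
    return prev.data_ptr() + prev.numel() * prev.element_size() == p.data_ptr()

def layout_offsets(params, align=64):
    # Assign each parameter an offset (in elements) into a flat gradient buffer,
    # rounding up to `align` elements (64 fp16 values = 128 bytes) only when the
    # current parameter is not stored back-to-back with the previous one.
    offsets, p_offset, prev = [], 0, None
    for p in params:
        if not p.requires_grad:
            continue
        offsets.append(p_offset)
        p_offset += p.numel()
        # As in the diff, padding is applied after accounting for p, so it
        # affects the offset handed to the *next* parameter.
        if prev is not None and not is_consecutive(prev, p):
            p_offset = ((p_offset + align - 1) // align) * align
        prev = p  # needed for the check; presumably updated in the elided part of the file
    return offsets, p_offset

# Two views into one flat buffer are consecutive; a separate allocation is not.
flat = torch.zeros(96, requires_grad=True)
a = flat[:32].view(4, 8)                 # ends exactly where b begins
b = flat[32:].view(8, 8)
c = torch.zeros(10, requires_grad=True)  # separately allocated
offsets, total = layout_offsets([a, b, c])
# offsets == [0, 32, 96]; total is rounded up to 128 because c is not consecutive with b

Presumably the point is that flattened RNN weights, which are views into one contiguous buffer, no longer get padding inserted between them, so the reduction buffer layout matches their actual memory layout.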