Commit 5d1993cf authored by Thor Johnsen

Don't pad between consecutive parameters

parent e1a4deba
@@ -101,6 +101,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         self._grads_info = []
         for group in self.param_groups:
             self._param_group = group
+            prev = None
             for p in group['params']:
                 torch.distributed.broadcast(p,0)
                 if not p.requires_grad:
@@ -119,8 +120,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
                 wrapper(p, p_i, p_grads_size, p_offset)
                 p_offset += p_grads_size
-                # enforce 128b alignment (64 * fp16)
-                p_offset = ((p_offset + 63) // 64) * 64
+                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
+                # RNN is one example of consecutive parameters:
+                # (weight_ih, weight_hh, bias_ih, bias_hh)
+                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
+                    p_offset = ((p_offset + 63) // 64) * 64
                 p_i += 1
         self._grads_generated = [False]*len(self._grads_info)
         self._flat_mt = flat_mt
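For context, here is a minimal sketch of the offset assignment outside the optimizer, not the commit's actual code: it assumes each gradient slot is p.numel() elements, uses a hypothetical ALIGN constant and flat_grad_offsets helper, and adds an explicit prev = p update that is not visible in the hunks above.

import torch

ALIGN = 64  # 64 fp16 elements == 128 bytes

def flat_grad_offsets(params):
    # Assign each parameter an offset in a flat fp16 gradient buffer,
    # padding to a 128-byte boundary only between parameters that are NOT
    # already laid out back-to-back in memory (hypothetical helper).
    offsets = []
    offset = 0
    prev = None
    for p in params:
        offsets.append(offset)
        offset += p.numel()
        # Pad only if p does not start exactly where prev ends in memory.
        if prev is not None and (
            prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()
        ):
            offset = ((offset + ALIGN - 1) // ALIGN) * ALIGN
        prev = p
    return offsets, offset

The expression ((offset + 63) // 64) * 64 rounds the running offset up to the next multiple of 64 fp16 elements, e.g. 100 -> 128. Skipping that pad between back-to-back parameters, such as an RNN's (weight_ih, weight_hh, bias_ih, bias_hh) block, presumably keeps the flat buffer's layout matching the parameters' actual memory layout so the block maps onto one contiguous slice.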