Commit da11cb76 authored by Rick Ho

broadcast distributed parameters

parent afd43f51
......
@@ -74,7 +74,36 @@ class DistributedGroupedDataParallel(nn.Module):
g.copy_(s)
self.allreduce_params = allreduce_params
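        # Broadcast the parameters once at construction so that every rank
        # starts from the same weights.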
self._sync_params()
def _sync_params(self):
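        # Bucket parameters by (communication group tag, dtype), flatten each
        # bucket into one contiguous tensor, and broadcast it from rank 0 of
        # the corresponding group.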
groups = dict()
for p in self.module.parameters():
            # Every parameter is synchronized, whether or not it has a
            # gradient yet (this runs from __init__, before any backward pass).
if hasattr(p, 'dp_comm'):
dp_comm = p.dp_comm
else:
dp_comm = 'dp'
group_key = (dp_comm, p.dtype)
if group_key not in groups:
groups[group_key] = [p]
else:
groups[group_key].append(p)
for (dp_comm, dtype), group in groups.items():
if dp_comm not in self.comms:
continue
comm = self.comms[dp_comm]
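            # Flatten the bucket into a single tensor, broadcast it from
            # rank 0 of the group, then copy the synced values back in place.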
datas = [p.data for p in group]
coalesced = _flatten_dense_tensors(datas)
            # Unlike the gradient all-reduce, a broadcast copies values
            # exactly, so the coalesced buffer needs no fp32 up-cast.
torch.distributed.broadcast(coalesced, 0, group=comm)
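            # Block the host until the broadcast has completed on the device.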
torch.cuda.synchronize()
synced = _unflatten_dense_tensors(coalesced, datas)
for d, s in zip(datas, synced):
d.copy_(s)
def forward(self, *args, **kwargs):
r'''
Directly call the module's forward function.
......
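For context, a minimal, hypothetical usage sketch (not taken from the repository): it assumes the wrapper takes the wrapped module as its first constructor argument with any communicator arguments left at their defaults, that fmoe.distributed is the import path, and that the process group is set up by the launcher (e.g. torchrun). With this commit, constructing the wrapper broadcasts rank 0's weights so all replicas start from identical values.

import torch
import torch.distributed as dist
from torch import nn

from fmoe.distributed import DistributedGroupedDataParallel  # assumed import path

dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Linear(16, 16).cuda()

# Constructing the wrapper calls _sync_params(), which broadcasts rank 0's
# parameters so every rank trains from identical initial weights.
model = DistributedGroupedDataParallel(model)

out = model(torch.randn(4, 16, device='cuda'))
out.sum().backward()

# Gradients are still reduced explicitly via the closure stored in __init__.
model.allreduce_params()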