Commit 6fe415e9 authored by Rick Ho's avatar Rick Ho
Browse files

fix bug for sync_params

parent 21a331fb
...@@ -78,16 +78,8 @@ class DistributedGroupedDataParallel(nn.Module): ...@@ -78,16 +78,8 @@ class DistributedGroupedDataParallel(nn.Module):
self._sync_params() self._sync_params()
def _sync_params(self): def _sync_params(self):
r"""
Note that this module does not guarantee initial consistency of
parameters. Users are supposed to manually initialize the model on
different workers with the same parameters using either this function
or other methods like pre-defined random seeds.
"""
groups = dict() groups = dict()
for p in self.module.parameters(): for p in self.module.parameters():
if not p.requires_grad or p.grad is None:
continue
if hasattr(p, "dp_comm"): if hasattr(p, "dp_comm"):
dp_comm = p.dp_comm dp_comm = p.dp_comm
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment