Unverified commit 89fa152b, authored by mcarilli, committed by GitHub
Browse files

Move other logic after forward to take advantage of GPU skew

parent 9d731777
......@@ -324,6 +324,8 @@ class DistributedDataParallel(Module):
def forward(self, *inputs, **kwargs):
result = self.module(*inputs, **kwargs)
if not self.delay_allreduce:
param_list = [param for param in self.module.parameters() if param.requires_grad]
......@@ -354,4 +356,4 @@ class DistributedDataParallel(Module):
self.callback_queued = False
return self.module(*inputs, **kwargs)
return result
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment