Unverified Commit 89fa152b authored by mcarilli's avatar mcarilli Committed by GitHub
Browse files

Move other logic after forward to take advantage of GPU skew

parent 9d731777
...@@ -324,6 +324,8 @@ class DistributedDataParallel(Module): ...@@ -324,6 +324,8 @@ class DistributedDataParallel(Module):
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
result = self.module(*inputs, **kwargs)
if not self.delay_allreduce: if not self.delay_allreduce:
param_list = [param for param in self.module.parameters() if param.requires_grad] param_list = [param for param in self.module.parameters() if param.requires_grad]
...@@ -354,4 +356,4 @@ class DistributedDataParallel(Module): ...@@ -354,4 +356,4 @@ class DistributedDataParallel(Module):
self.callback_queued = False self.callback_queued = False
return self.module(*inputs, **kwargs) return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment