Commit ed14f39c authored by Michael Carilli

Fixing needs_refresh logic to allow multiple forwards between each backward

parent 586c507e
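This commit makes forward() only ever promote needs_refresh to True; clearing it back to False is left to allreduce_params in backward. One pattern this supports is running several forward passes before a single backward (for example, accumulating losses). A minimal sketch of that usage; a bare nn.Linear stands in for a DistributedDataParallel-wrapped model so the snippet runs without a process group, and the optimizer and batch names are illustrative assumptions:

    import torch
    import torch.nn as nn

    # Illustrative stand-in: in real use `model` would be a module wrapped in
    # apex's DistributedDataParallel and gradients would be all-reduced.
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batches = [torch.randn(8, 4) for _ in range(3)]

    losses = []
    for batch in batches:                  # several forward passes...
        losses.append(model(batch).pow(2).mean())
    torch.stack(losses).sum().backward()   # ...then one backward fires the gradient hooks
    optimizer.step()
    optimizer.zero_grad()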
@@ -133,7 +133,7 @@ class DistributedDataParallel(Module):
         self.shared_param = shared_param
         self.message_size = message_size
-        #reference to last iterations parameters to see if anything has changed
+        # reference to last iteration's parameters to see if anything has changed
         self.param_refs = []
         self.reduction_stream = torch.cuda.Stream()
@@ -162,13 +162,13 @@ class DistributedDataParallel(Module):
     def create_hooks(self):
-        #all reduce gradient hook
+        # all reduce gradient hook
         def allreduce_params():
             if not self.needs_reduction:
                 return
             self.needs_reduction = False
-            #parameter ordering refresh
+            # parameter ordering refresh
             if self.needs_refresh and not self.shared_param:
                 t_record = torch.cuda.IntTensor(self.record)
                 dist.broadcast(t_record, 0)
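For context on the hook above: when a refresh is pending, each rank copies its bucketing order (self.record, a list of parameter indices built during backward) to the GPU, and rank 0 broadcasts its copy so every rank reduces gradients in the same order. A standalone sketch of just that step, assuming an initialized process group and CUDA; the function name is illustrative, not apex API:

    import torch
    import torch.distributed as dist

    def sync_record(record):
        # Move the locally built ordering to the GPU for the collective,
        # then adopt rank 0's copy on every rank.
        t_record = torch.cuda.IntTensor(record)
        dist.broadcast(t_record, 0)
        return [int(entry) for entry in t_record]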
@@ -267,21 +267,18 @@ class DistributedDataParallel(Module):
         param_list = [param for param in self.module.parameters() if param.requires_grad]
-        #Force needs_refresh to True if there are shared params
-        #this will force it to always, only call flush_buckets which is safe
-        #for shared parameters in the model.
-        #Parentheses are not necessary for correct order of operations, but make the intent clearer.
-        if (not self.param_refs) or self.shared_param:
-            self.needs_refresh = True
-        else:
-            self.needs_refresh = (
-                (len(param_list) != len(self.param_refs)) or any(
-                [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]))
+        # Conditions under which to refresh self.record
+        # Forward has the authority to set needs_refresh to True, but only allreduce_params
+        # in backward has the authority to set needs_refresh to False.
+        # Parentheses are not necessary for correct order of operations, but make the intent clearer.
+        if ((not self.param_refs) or
+            self.shared_param or
+            (len(param_list) != len(self.param_refs)) or
+            any([param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)])):
+            self.needs_refresh = True
         if self.needs_refresh:
             self.record = []
             self.param_state = [0 for i in range(len(param_list))]
         self.param_refs = param_list
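Why the authority split matters: with the old else branch, a second forward before any backward recomputed needs_refresh and flipped it back to False (param_refs already matched param_list), so the eventual backward skipped the record refresh. A toy reproduction of the two policies, pure Python rather than apex code, with illustrative names:

    def forward_old(state, param_list):
        if (not state["param_refs"]) or state["shared_param"]:
            state["needs_refresh"] = True
        else:
            # Bug: forward could overwrite True with False before backward ran.
            state["needs_refresh"] = (
                len(param_list) != len(state["param_refs"]) or
                any(p1 is not p2 for p1, p2 in zip(param_list, state["param_refs"])))
        state["param_refs"] = param_list

    def forward_new(state, param_list):
        if ((not state["param_refs"]) or state["shared_param"] or
            len(param_list) != len(state["param_refs"]) or
            any(p1 is not p2 for p1, p2 in zip(param_list, state["param_refs"]))):
            state["needs_refresh"] = True   # forward only ever sets it to True
        state["param_refs"] = param_list

    params = [object(), object()]
    for fwd in (forward_old, forward_new):
        state = {"param_refs": [], "shared_param": False, "needs_refresh": False}
        fwd(state, params)  # first forward: needs_refresh becomes True
        fwd(state, params)  # second forward, still no backward in between
        print(fwd.__name__, state["needs_refresh"])  # old -> False, new -> True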