Commit ed14f39c authored by Michael Carilli

Fixing needs_refresh logic to allow multiple forwards between each backward

parent 586c507e
@@ -133,7 +133,7 @@ class DistributedDataParallel(Module):
         self.shared_param = shared_param
         self.message_size = message_size
-        #reference to last iterations parameters to see if anything has changed
+        # reference to last iterations parameters to see if anything has changed
         self.param_refs = []
         self.reduction_stream = torch.cuda.Stream()
@@ -162,13 +162,13 @@ class DistributedDataParallel(Module):
     def create_hooks(self):
-        #all reduce gradient hook
+        # all reduce gradient hook
         def allreduce_params():
             if not self.needs_reduction:
                 return
             self.needs_reduction = False
-            #parameter ordering refresh
+            # parameter ordering refresh
             if self.needs_refresh and not self.shared_param:
                 t_record = torch.cuda.IntTensor(self.record)
                 dist.broadcast(t_record, 0)
@@ -267,22 +267,19 @@ class DistributedDataParallel(Module):
         param_list = [param for param in self.module.parameters() if param.requires_grad]
-        #Force needs_refresh to True if there are shared params
-        #this will force it to always, only call flush_buckets which is safe
-        #for shared parameters in the model.
-        #Parentheses are not necessary for correct order of operations, but make the intent clearer.
-        if (not self.param_refs) or self.shared_param:
+        # Conditions under which to refresh self.record
+        # Forward has the authority to set needs_refresh to True, but only allreduce_params
+        # in backward has the authority to set needs_refresh to False.
+        # Parentheses are not necessary for correct order of operations, but make the intent clearer.
+        if ((not self.param_refs) or
+            self.shared_param or
+            (len(param_list) != len(self.param_refs)) or
+            any([param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)])):
             self.needs_refresh = True
-        else:
-            self.needs_refresh = (
-                (len(param_list) != len(self.param_refs)) or any(
-                [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]))
         if self.needs_refresh:
             self.record = []
             self.param_state = [0 for i in range(len(param_list))]
             self.param_refs = param_list
         self.needs_reduction = True
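
The intent of the change reads as a handshake: any forward pass may promote needs_refresh to True (empty or changed parameter list, or shared parameters in use), but only the backward-side allreduce_params hook may resolve it back to False. Below is a minimal sketch of that handshake, not apex's actual implementation; the class name ToyDistributedDataParallel, the simplified method bodies, and the point at which needs_refresh is cleared are assumptions for illustration, and all collective communication is elided.

# Minimal sketch of the needs_refresh / needs_reduction handshake described by the
# new comments in this commit. NOT apex's implementation: ToyDistributedDataParallel
# is a hypothetical stand-in and broadcast/bucketing work is elided.
class ToyDistributedDataParallel:
    def __init__(self, module, shared_param=False):
        self.module = module
        self.shared_param = shared_param
        self.param_refs = []          # parameters seen by the previous forward
        self.record = []              # gradient-ready ordering recorded on a refresh pass
        self.needs_refresh = True     # forward may set this True; only backward clears it
        self.needs_reduction = False

    def forward(self, *inputs):
        param_list = [p for p in self.module.parameters() if p.requires_grad]
        # Forward only ever promotes needs_refresh to True, so running several
        # forwards between backwards cannot cancel a pending refresh.
        if ((not self.param_refs) or
                self.shared_param or
                (len(param_list) != len(self.param_refs)) or
                any(p1 is not p2 for p1, p2 in zip(param_list, self.param_refs))):
            self.needs_refresh = True
        if self.needs_refresh:
            self.record = []
            self.param_refs = param_list
        self.needs_reduction = True
        return self.module(*inputs)

    def allreduce_params(self):
        # Invoked from the backward hook; assumed here to be the only place
        # needs_refresh is resolved back to False.
        if not self.needs_reduction:
            return
        self.needs_reduction = False
        if self.needs_refresh and not self.shared_param:
            pass  # in the real code: broadcast self.record and rebuild buckets here
        self.needs_refresh = False

With this split of authority, two or three forwards followed by a single backward behave like one forward per backward: the refresh decision accumulated across the forwards is acted on exactly once, inside the backward hook.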