Commit 137d822b authored by Christian Sarofeen

Better shortcut for shared_param = True

parent fb075b86
@@ -79,17 +79,19 @@ class DistributedDataParallel(Module):
     def create_hooks(self):
         #all reduce gradient hook
         def allreduce_params():
-            if(self.needs_reduction):
-                self.needs_reduction = False
-                self.needs_refresh = False
-            else:
+            if not self.needs_reduction:
                 return
+            self.needs_reduction = False
+            #parameter ordering refresh
+            if self.needs_refresh and not self.shared_param:
+                t_record = torch.cuda.IntTensor(self.record)
+                dist.broadcast(t_record, 0)
+                self.record = [int(entry) for entry in t_record]
+                self.needs_refresh = False
             grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
             flat_dist_call(grads, dist.all_reduce)
-            t_record = torch.cuda.IntTensor(self.record)
-            dist.broadcast(t_record, 0)
-            self.record = [int(entry) for entry in t_record]
         def flush_buckets():
             if not self.needs_reduction:
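The reworked hook takes an early exit when no reduction is pending, broadcasts the rank-0 parameter ordering (`self.record`) only while a refresh is actually needed and never for shared parameters, and then issues a single flattened all-reduce over all gradients. `flat_dist_call` itself is not part of this diff; a minimal sketch of the flatten-call-unflatten pattern its name and call site suggest, using PyTorch's tensor coalescing helpers (a simplified reconstruction assuming all gradients share one dtype, not the actual implementation):

```python
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def flat_dist_call(tensors, call):
    # Coalesce many small gradients into one contiguous buffer so the
    # collective launches once instead of once per parameter.
    flat = _flatten_dense_tensors(tensors)
    call(flat)  # e.g. dist.all_reduce(flat), as in the hook above
    # Scatter the reduced values back into the original gradient tensors.
    for buf, synced in zip(tensors, _unflatten_dense_tensors(flat, tensors)):
        buf.copy_(synced)
```

With this structure, the flattened collective runs on every reduction, while the record broadcast runs only until the ordering has been refreshed once.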
@@ -184,12 +186,12 @@ class DistributedDataParallel(Module):
         #Force needs_refresh to True if there are shared params
         #this will force it to always, only call flush_buckets which is safe
         #for shared parameters in the model.
-        if self.shared_param:
-            self.param_refs = []
-        self.needs_refresh = True if not self.param_refs else any(
-            [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]
-        )
+        if not self.param_refs or self.shared_param:
+            self.needs_refresh = True
+        else:
+            self.needs_refresh = any(
+                [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]
+            )
         if self.needs_refresh:
             self.record = []
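This hunk is the shortcut the commit message refers to: with `shared_param = True` (or no cached parameter references yet), the code now sets `needs_refresh = True` directly instead of clearing `param_refs` and relying on the ternary to reach the same result. Note the check uses object identity (`is not`) rather than value equality, so it detects a reordered or swapped-out parameter list without ever comparing tensor data. A standalone illustration of the same decision logic (hypothetical helper name, not part of the commit):

```python
def ordering_changed(param_list, param_refs, shared_param):
    # No cached ordering yet, or shared parameters make the traversal
    # order ambiguous: always refresh (forcing the flush_buckets path,
    # which is safe for shared parameters).
    if not param_refs or shared_param:
        return True
    # Identity comparison: any replaced Parameter object forces a refresh.
    return any(p1 is not p2 for p1, p2 in zip(param_list, param_refs))
```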