Unverified Commit 53e1b61a authored by mcarilli, committed by GitHub

Fix param freezing (#47)

* Fix appears to work in Tomasz's example.

* Somehow shared_param got disabled again?
parent ed47ebff
@@ -133,26 +133,27 @@ class DistributedDataParallel(Module):
         # Backward/forward compatibility around
         # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36
         if(hasattr(dist, "get_backend")):
             self._backend = dist.get_backend()
             self.backend_enum_holder = dist.DistBackend
         else:
             self._backend = dist._backend
             self.backend_enum_holder = dist.dist_backend

         self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
         self.shared_param = shared_param
         self.message_size = message_size

-        # reference to last iterations parameters to see if anything has changed
-        self.param_refs = []
+        # Will hold [param for param in self.module.parameters() if param.requires_grad],
+        # aka the active parameters this iteration.  The ordering of this list will be
+        # the same across all processes.
+        self.active_params = []

         self.reduction_stream = torch.cuda.Stream()
         self.module = module
-        self.param_list = list(self.module.parameters())

         if self._backend == self.backend_enum_holder.NCCL:
-            for param in self.param_list:
+            for param in self.module.parameters():
                 assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."

         self.record = []
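The hunk above replaces `self.param_refs` with `self.active_params`: the list of parameters that currently require gradients, in an order shared by all processes. As a minimal standalone sketch of why that matters for the "param freezing" fix (the toy model and the freezing step below are hypothetical, not part of apex):

```python
import torch.nn as nn

# Hypothetical two-layer model, used only to illustrate the "active params" idea.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

def active_params(module):
    # Same filter the wrapper applies: only params that require grad take part
    # in gradient reduction this iteration.
    return [p for p in module.parameters() if p.requires_grad]

print(len(active_params(model)))   # 4 tensors: two weights, two biases

# Freeze the first layer; the active set shrinks, so indices recorded against the
# old list are stale and the wrapper's bookkeeping has to be refreshed.
for p in model[0].parameters():
    p.requires_grad = False

print(len(active_params(model)))   # 2 tensors remain active
```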
@@ -160,10 +161,12 @@ class DistributedDataParallel(Module):
         flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )

     def __setstate__(self, state):
         super(DistributedDataParallel, self).__setstate__(state)
         self.reduction_stream = torch.cuda.Stream()

     def __getstate__(self):
         attrs = copy.copy(self.__dict__)
         if self._backend != self.backend_enum_holder.NCCL:
@@ -183,6 +186,11 @@ class DistributedDataParallel(Module):
             t_record = torch.cuda.IntTensor(self.record)
             dist.broadcast(t_record, 0)
             self.record = [int(entry) for entry in t_record]
+            # As before, self.record stores a list of indexes into self.active_params.
+            # param_id_to_record_i is a map from each active param's id to its slot in
+            # self.record.
+            self.param_id_to_record_i = {id(self.active_params[a]) : i
+                                         for i, a in enumerate(self.record)}
             self.needs_refresh = False

         grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
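The new `param_id_to_record_i` dict maps each active parameter's `id()` to its slot in `self.record`, so the backward hook can replace an O(N) `list.index` scan with an O(1) lookup. A self-contained sketch of that bookkeeping, using simplified stand-ins for the wrapper's attributes:

```python
import torch

# Simplified stand-ins for the wrapper's attributes: a few parameter tensors and
# a "record" of the order in which their gradients became ready last iteration.
active_params = [torch.nn.Parameter(torch.zeros(3)) for _ in range(4)]
record = [2, 0, 3, 1]   # indexes into active_params

# O(1) lookup table: parameter identity -> slot in `record`.
param_id_to_record_i = {id(active_params[a]): i for i, a in enumerate(record)}

p = active_params[3]
slot_slow = record.index(3)                # old approach: O(N) list scan per hook call
slot_fast = param_id_to_record_i[id(p)]    # new approach: O(1) dict lookup
assert slot_slow == slot_fast == 2
```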
@@ -195,7 +203,7 @@ class DistributedDataParallel(Module):
             grads = []
             for i in range(self.ready_end, len(self.param_state)):
-                param = self.param_refs[self.record[i]]
+                param = self.active_params[self.record[i]]
                 if param.grad is not None:
                     grads.append(param.grad.data)
             grads = [param.grad.data for param in self.ready_params] + grads
@@ -208,42 +216,41 @@ class DistributedDataParallel(Module):
             torch.cuda.current_stream().wait_stream(self.reduction_stream)

-        for param_i, param in enumerate([p for p in self.module.parameters() if p.requires_grad]):
-            def wrapper(param_i):
-                def allreduce_hook(*unused):
-                    if self.needs_refresh:
-                        self.record.append(param_i)
-                        Variable._execution_engine.queue_callback(allreduce_params)
-                    else:
-                        Variable._execution_engine.queue_callback(flush_buckets)
-                        self.comm_ready_buckets(self.record.index(param_i))
-
-                if param.requires_grad:
+        for param in self.module.parameters():
+            if param.requires_grad:
+                def wrapper(param):
+                    def allreduce_hook(*unused):
+                        if self.needs_refresh:
+                            self.record.append(self.param_id_to_active_i[id(param)])
+                            Variable._execution_engine.queue_callback(allreduce_params)
+                        else:
+                            Variable._execution_engine.queue_callback(flush_buckets)
+                            # param_id_to_record_i handily enables us to replace the
+                            # O(N) self.record.index(param_i) call with an O(1) dict lookup.
+                            self.comm_ready_buckets(self.param_id_to_record_i[id(param)])

                     param.register_hook(allreduce_hook)
-            wrapper(param_i)
+                wrapper(param)

-    def comm_ready_buckets(self, param_ind):
+    def comm_ready_buckets(self, record_i):

-        if self.param_state[param_ind] != 0:
+        if self.param_state[record_i] != 0:
             raise RuntimeError("Error: Your model uses shared parameters, DDP flag shared_params must be set to True in initialization.")

         if self.param_state[self.ready_end] == 0:
-            self.param_state[param_ind] = 1
+            self.param_state[record_i] = 1
             return

         while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1:
-            self.ready_params.append(self.param_refs[self.record[self.ready_end]])
+            self.ready_params.append(self.active_params[self.record[self.ready_end]])
             self.ready_numel += self.ready_params[-1].numel()
             self.ready_end += 1

         if self.ready_numel < self.message_size:
-            self.param_state[param_ind] = 1
+            self.param_state[record_i] = 1
             return

         grads = [param.grad.data for param in self.ready_params]
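The rewritten registration loop closes over the parameter object itself instead of a loop index, and keys all bookkeeping off `id(param)`. A minimal sketch of that per-parameter `register_hook` closure pattern (the model and the ready-order list are illustrative, not apex internals):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)   # illustrative model
param_id_to_active_i = {id(p): i for i, p in enumerate(model.parameters())}
ready_order = []          # records the order in which gradients become ready

for param in model.parameters():
    if param.requires_grad:
        # The wrapper exists so each hook closes over its own `param`;
        # a hook that referenced the loop variable directly would see only
        # the final value of `param` after the loop finishes.
        def make_hook(param):
            def grad_hook(grad):
                ready_order.append(param_id_to_active_i[id(param)])
                return grad   # a hook may return the (possibly modified) gradient
            return grad_hook
        param.register_hook(make_hook(param))

model(torch.randn(1, 4)).sum().backward()
print(ready_order)   # e.g. [1, 0] -- the bias grad is often ready before the weight grad
```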
@@ -271,7 +278,8 @@ class DistributedDataParallel(Module):
                 self.param_state[i] = 2
                 self.ready_params.pop(0)

-        self.param_state[param_ind] = 1
+        self.param_state[record_i] = 1

     def forward(self, *inputs, **kwargs):
@@ -281,17 +289,19 @@ class DistributedDataParallel(Module):
         # Forward has the authority to set needs_refresh to True, but only allreduce_params
         # in backward has the authority to set needs_refresh to False.
         # Parentheses are not necessary for correct order of operations, but make the intent clearer.
-        if ( (not self.param_refs) or
+        if ( (not self.active_params) or
              self.shared_param or
-             (len(param_list) != len(self.param_refs)) or
-             any([param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]) ):
+             (len(param_list) != len(self.active_params)) or
+             any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)]) ):
             self.needs_refresh = True

         if self.needs_refresh:
             self.record = []
+            # Map from each param's id to its index in the list of active parameters.
+            self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
             self.param_state = [0 for i in range(len(param_list))]
-            self.param_refs = param_list
+            self.active_params = param_list
             self.needs_reduction = True

         self.ready_start = 0
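`forward` decides whether to rebuild the bookkeeping by comparing this iteration's trainable parameters against the previous iteration's list by identity. A compact sketch of that needs-refresh check, with hypothetical free-standing names in place of the class attributes:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
prev_active = []   # what the wrapper remembered from the last iteration

def needs_refresh(module, prev_active, shared_param=False):
    # Mirrors the condition in forward(): refresh if there is no history yet,
    # if shared_param forces the non-bucketed path every iteration, or if the
    # set of trainable parameters changed in length or identity.
    current = [p for p in module.parameters() if p.requires_grad]
    changed = (not prev_active
               or shared_param
               or len(current) != len(prev_active)
               or any(a is not b for a, b in zip(current, prev_active)))
    return changed, current

refresh, prev_active = needs_refresh(model, prev_active)
print(refresh)   # True: first iteration, no history yet

refresh, prev_active = needs_refresh(model, prev_active)
print(refresh)   # False: same parameters as last time

for p in model[0].parameters():
    p.requires_grad = False   # freeze a layer mid-training
refresh, prev_active = needs_refresh(model, prev_active)
print(refresh)   # True again: the active set changed, bookkeeping must be rebuilt
```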
@@ -122,7 +122,7 @@ def main():
         model = network_to_half(model)

     if args.distributed:
         # shared_param turns off bucketing in DDP; for lower-latency runs this can improve perf.
-        model = DDP(model)
+        model = DDP(model, shared_param=True)

     global model_params, master_params
     if args.fp16:
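The example now passes `shared_param=True` again when wrapping the model. A hedged usage sketch, assuming apex's `DistributedDataParallel` import path and a typical one-GPU-per-process launch (the init details and the model are placeholders, not taken from the diff):

```python
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

# Assumes the usual env:// launcher variables (RANK, WORLD_SIZE, MASTER_ADDR, ...)
# are set and there is one GPU per process; these init details are illustrative.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()   # placeholder model

# shared_param=True forces the needs_refresh path every iteration, which skips
# gradient bucketing; per the comment in the example, this can improve perf for
# lower-latency runs, and it is required when the model shares parameters.
model = DDP(model, shared_param=True)
```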