Commit 343590a1 authored by Christian Sarofeen's avatar Christian Sarofeen
Browse files

Revert "Revert distributed to simpler version until backward hooks are fixed."

This reverts commit 47ac5c2b.
parent 61c1e160
......@@ -4,6 +4,7 @@ import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
def flat_dist_call(tensors, call, extra_args=None):
flat_dist_call.warn_on_half = True
buckets = {}
......@@ -28,79 +29,179 @@ def flat_dist_call(tensors, call, extra_args=None):
call(coalesced)
if call is dist.all_reduce:
coalesced /= dist.get_world_size()
for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
buf.copy_(synced)
class DistributedDataParallel(Module):
"""
:class:`DistributedDataParallel` is a simpler version of upstream :class:`
DistributedDataParallel`. Its usage is designed to be used in conjunction with
apex.parallel.multiproc.py. It assumes that your run is using multiprocess with
1 GPU/process, that the model is on the correct device, and that
torch.set_device has been used to set the device. Parameters are broadcasted
DistributedDataParallel` that is optimized for use with NCCL. Its usage is designed
to be used in conjunction with apex.parallel.multiproc.py. It assumes that your run
is using multiprocess with 1 GPU/process, that the model is on the correct device,
and that torch.set_device has been used to set the device. Parameters are broadcasted
to the other processes on initialization of DistributedDataParallel, and will be
allreduced in buckets durring the backward pass.
See https://github.com/csarofeen/examples/tree/apex/distributed for detailed usage.
Args:
module: Network definition to be run in multi-gpu/distributed mode.
message_size (Default = 10000000): Minimum number of elements in a communication bucket.
message_size (Default = 100e6): Minimum number of elements in a communication bucket.
shared_param (Default = False): If your model uses shared parameters this must be true,
it will disable bucketing of parameters which is necessary to avoid race conditions.
"""
def __init__(self, module):
def __init__(self, module, message_size=100000000, shared_param=False):
super(DistributedDataParallel, self).__init__()
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
self.shared_param = shared_param
self.message_size = message_size
#reference to last iterations parameters to see if anything has changed
self.param_refs = []
self.reduction_stream = torch.cuda.Stream()
self.module = module
param_list = [param for param in self.module.state_dict().values() if torch.is_tensor(param)]
self.param_list = list(self.module.parameters())
if dist._backend == dist.dist_backend.NCCL:
for param in param_list:
for param in self.param_list:
assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
#broadcast parameters
flat_dist_call(param_list, dist.broadcast, (0,) )
self.record = []
self.create_hooks()
flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
def create_hooks(self):
#all reduce gradient hook
def allreduce_params():
if(self.needs_reduction):
self.needs_reduction = False
self.needs_refresh = False
else:
return
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
flat_dist_call(grads, dist.all_reduce)
t_record = torch.cuda.IntTensor(self.record)
dist.broadcast(t_record, 0)
self.record = [int(entry) for entry in t_record]
def flush_buckets():
if not self.needs_reduction:
return
self.needs_reduction = False
grads = []
for i in range(self.ready_end, len(self.param_state)):
param = self.param_refs[self.record[i]]
if param.grad is not None:
grads.append(param.grad.data)
grads = [param.grad.data for param in self.ready_params] + grads
if(len(grads)>0):
orig_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.reduction_stream):
self.reduction_stream.wait_stream(orig_stream)
flat_dist_call(grads, dist.all_reduce)
torch.cuda.current_stream().wait_stream(self.reduction_stream)
for param_i, param in enumerate(list(self.module.parameters())):
def wrapper(param_i):
def allreduce_hook(*unused):
if self.needs_refresh:
self.record.append(param_i)
Variable._execution_engine.queue_callback(allreduce_params)
else:
Variable._execution_engine.queue_callback(flush_buckets)
self.comm_ready_buckets(self.record.index(param_i))
if param.requires_grad:
param.register_hook(allreduce_hook)
wrapper(param_i)
def comm_ready_buckets(self, param_ind):
if self.param_state[param_ind] != 0:
raise RuntimeError("Error: Your model uses shared parameters, DDP flag shared_params must be set to True in initialization.")
if self.param_state[self.ready_end] == 0:
self.param_state[param_ind] = 1
return
while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1:
self.ready_params.append(self.param_refs[self.record[self.ready_end]])
self.ready_numel += self.ready_params[-1].numel()
self.ready_end += 1
if self.ready_numel < self.message_size:
self.param_state[param_ind] = 1
return
grads = [param.grad.data for param in self.ready_params]
bucket = []
bucket_inds = []
while grads:
bucket.append(grads.pop(0))
for param in list(self.module.parameters()):
def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params)
if param.requires_grad:
param.register_hook(allreduce_hook)
cumm_size = 0
for ten in bucket:
cumm_size += ten.numel()
if cumm_size < self.message_size:
continue
evt = torch.cuda.Event()
evt.record(torch.cuda.current_stream())
evt.wait(stream=self.reduction_stream)
with torch.cuda.stream(self.reduction_stream):
flat_dist_call(bucket, dist.all_reduce)
for i in range(self.ready_start, self.ready_start+len(bucket)):
self.param_state[i] = 2
self.ready_params.pop(0)
self.param_state[param_ind] = 1
def forward(self, *inputs, **kwargs):
param_list = [param for param in list(self.module.parameters()) if param.requires_grad]
#Force needs_refresh to True if there are shared params
#this will force it to always, only call flush_buckets which is safe
#for shared parameters in the model.
if self.shared_param:
self.param_refs = []
self.needs_refresh = True if not self.param_refs else any(
[param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]
)
if self.needs_refresh:
self.record = []
self.param_state = [0 for i in range(len(param_list))]
self.param_refs = param_list
self.needs_reduction = True
return self.module(*inputs, **kwargs)
'''
def _sync_buffers(self):
buffers = list(self.module._all_buffers())
if len(buffers) > 0:
# cross-node buffer sync
flat_buffers = _flatten_dense_tensors(buffers)
dist.broadcast(flat_buffers, 0)
for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
buf.copy_(synced)
def train(self, mode=True):
# Clear NCCL communicator and CUDA event cache of the default group ID,
# These cache will be recreated at the later call. This is currently a
# work-around for a potential NCCL deadlock.
if dist._backend == dist.dist_backend.NCCL:
dist._clear_group_cache()
super(DistributedDataParallel, self).train(mode)
self.module.train(mode)
'''
self.ready_start = 0
self.ready_end = 0
self.ready_params = []
self.ready_numel = 0
return self.module(*inputs, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment