Commit 8521bb22 authored by Michael Carilli

Patching in changes to enable multiple allreduces in flight

parent 61b8a0fd
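For context, here is a minimal usage sketch of the options this commit introduces. It is not part of the diff: the import path, model, and initialization calls are assumptions, and it only presumes the constructor shown below (an apex-style `DistributedDataParallel` wrapper taking `delay_allreduce`, `allreduce_different_streams`, and `prof` keyword arguments) plus an already-initialized NCCL process group.

```python
# Hypothetical usage sketch of the new flags in this commit; the import path
# and model are placeholders, not taken from the diff.
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel  # assumed import path

dist.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()
model = DistributedDataParallel(
    model,
    delay_allreduce=False,             # required: the new check rejects delay_allreduce=True
    allreduce_different_streams=True,  # new: one stream/process group per gradient bucket
    prof=True)                         # new: emit NVTX ranges around the DDP hooks
```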
@@ -169,9 +169,11 @@ class DistributedDataParallel(Module):
allreduce_trigger_params=None,
retain_allreduce_buffers=False,
allreduce_always_fp32=False,
allreduce_different_streams=False,
gradient_average=True,
gradient_predivide_factor=1.0,
gradient_average_split_factor=None):
gradient_average_split_factor=None,
prof=False):
super(DistributedDataParallel, self).__init__()
# Backward/forward compatibility around
@@ -189,6 +191,13 @@ class DistributedDataParallel(Module):
self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
self.prof = prof
if allreduce_different_streams and delay_allreduce:
raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
self.allreduce_different_streams = allreduce_different_streams
if shared_param is not None:
raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
@@ -213,8 +222,8 @@ class DistributedDataParallel(Module):
self.delay_allreduce = delay_allreduce
self.message_size = message_size
self.reduction_stream = torch.cuda.Stream()
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.bucket_streams = []
self.bucket_events = []
self.module = module
@@ -241,15 +250,21 @@ class DistributedDataParallel(Module):
def __setstate__(self, state):
super(DistributedDataParallel, self).__setstate__(state)
self.reduction_stream = torch.cuda.Stream()
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
if self.allreduce_different_streams and self.delay_allreduce:
raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
if self.delay_allreduce:
self.needs_refresh = True
self.bucket_streams = []
self.bucket_events = []
def __getstate__(self):
attrs = copy.copy(self.__dict__)
if self._backend != self.backend_enum_holder.NCCL:
del attrs['self.reduction_stream']
del attrs['self.reduction_event']
del attrs['bucket_streams']
del attrs['bucket_events']
return attrs
# Broadcast rank 0's bucket structure across all processes, and have all processes
@@ -307,8 +322,9 @@ class DistributedDataParallel(Module):
def overlapping_backward_epilogue():
self.reduction_stream.record_event(self.reduction_event)
torch.cuda.current_stream().wait_event(self.reduction_event)
for stream, event in zip(self.bucket_streams, self.bucket_events):
stream.record_event(event)
torch.cuda.current_stream().wait_event(event)
# Sanity checks that all the buckets were kicked off
if self.next_bucket != self.num_buckets:
@@ -329,6 +345,9 @@ class DistributedDataParallel(Module):
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
if self.prof:
torch.cuda.nvtx.range_push("allreduce_hook")
if self.delay_allreduce or self.needs_refresh:
# TODO: How do we want to handle multiple backward passes between
# each forward, e.g., backward passes with retain_graph=True?
@@ -368,11 +387,30 @@ class DistributedDataParallel(Module):
self.comm_ready_buckets(param)
if self.prof:
torch.cuda.nvtx.range_pop()
grad_acc.register_hook(allreduce_hook)
self.grad_accs.append(grad_acc)
wrapper(param)
def _stream_this_bucket(self, bucket_idx):
if self.allreduce_different_streams:
return self.bucket_streams[bucket_idx]
else:
return self.bucket_streams[0]
def _event_this_bucket(self, bucket_idx):
if self.allreduce_different_streams:
return self.bucket_events[bucket_idx]
else:
return self.bucket_events[0]
def allreduce_bucket(self, bucket, bucket_idx=-1):
tensor = flatten(bucket)
@@ -384,6 +422,9 @@ class DistributedDataParallel(Module):
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
if self.allreduce_different_streams and self.bucket_pgs:
dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
else:
dist.all_reduce(tensor_to_allreduce)
if self.gradient_average:
@@ -396,7 +437,7 @@ class DistributedDataParallel(Module):
def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
allreduced = self.allreduce_bucket(bucket)
allreduced = self.allreduce_bucket(bucket, bucket_idx)
if self.retain_allreduce_buffers:
if self.allreduce_buffers[bucket_idx] is not None:
raise RuntimeError("The backward pass is attempting to replace an already-filled "
@@ -432,6 +473,8 @@ class DistributedDataParallel(Module):
def comm_ready_buckets(self, param):
# Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
# self.reduction_stream.wait_stream(torch.cuda.current_stream())
if self.prof:
torch.cuda.nvtx.range_push("comm_ready_buckets")
bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]
@@ -444,9 +487,11 @@ class DistributedDataParallel(Module):
if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
if bucket_idx == self.next_bucket:
torch.cuda.current_stream().record_event(self.reduction_event)
self.reduction_stream.wait_event(self.reduction_event)
with torch.cuda.stream(self.reduction_stream):
bucket_stream = self._stream_this_bucket(bucket_idx)
bucket_event = self._event_this_bucket(bucket_idx)
torch.cuda.current_stream().record_event(bucket_event)
bucket_stream.wait_event(bucket_event)
with torch.cuda.stream(bucket_stream):
self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
self.next_bucket += 1
@@ -468,10 +513,16 @@ class DistributedDataParallel(Module):
else:
self.ready_buckets_not_reduced.add(bucket_idx)
if self.prof:
torch.cuda.nvtx.range_pop()
def forward(self, *inputs, **kwargs):
result = self.module(*inputs, **kwargs)
if self.prof:
torch.cuda.nvtx.range_push("forward pass DDP logic")
if not self.delay_allreduce:
param_list = [param for param in self.module.parameters() if param.requires_grad]
@@ -492,9 +543,40 @@ class DistributedDataParallel(Module):
self.bucket_sizes = []
self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
self.param_id_to_bucket = {}
self.bucket_pgs = []
self.bucket_streams = []
self.bucket_events = []
else:
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
if not self.buckets:
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
else:
assert len(self.buckets) == self.num_buckets, "len(buckets) = {}, expected {}".format(
len(self.buckets), self.num_buckets)
for b, bucket in enumerate(self.buckets):
assert len(bucket) == self.bucket_sizes[b], "len(buckets[{}]) = {}, expected {}".format(
b, len(bucket), self.bucket_sizes[b])
for i in range(len(bucket)):
bucket[i] = None
if self.allreduce_different_streams:
if not self.bucket_pgs:
self.bucket_pgs = [dist.new_group() for _ in range(self.num_buckets)]
for i, bg in enumerate(self.bucket_pgs):
print("rank {} created group {} with backend {}".format(
dist.get_rank(), i, dist.get_backend(bg)))
if self.allreduce_different_streams:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_buckets)]
self.bucket_events = [torch.cuda.Event(enable_timing=False,
blocking=False) for _ in range(self.num_buckets)]
else:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream()]
self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]
self.buckets_ready_size = [0 for i in range(self.num_buckets)]
if(self.retain_allreduce_buffers):
self.allreduce_buffers = [None for _ in range(self.num_buckets)]
@@ -505,4 +587,7 @@ class DistributedDataParallel(Module):
self.callback_queued = False
if self.prof:
torch.cuda.nvtx.range_pop()
return result
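As a standalone illustration of the pattern the diff implements (a CUDA stream, event, and process group per gradient bucket, so several allreduces can be in flight while the backward pass continues), here is a minimal sketch. It is not code from the commit: the function names and bucket handling are made up, and it assumes an initialized NCCL backend; note that every rank must call `dist.new_group()` the same number of times.

```python
# Illustrative sketch of the per-bucket overlap pattern (not part of the commit).
import torch
import torch.distributed as dist

num_buckets = 4
streams = [torch.cuda.Stream() for _ in range(num_buckets)]
events = [torch.cuda.Event(enable_timing=False, blocking=False) for _ in range(num_buckets)]
groups = [dist.new_group() for _ in range(num_buckets)]  # one process group per bucket

def allreduce_bucket_async(flat_grad, bucket_idx):
    # Make the bucket's stream wait until the producer (backward) has written
    # the flattened gradients ...
    torch.cuda.current_stream().record_event(events[bucket_idx])
    streams[bucket_idx].wait_event(events[bucket_idx])
    # ... then launch the allreduce on that stream so it overlaps with the rest
    # of backward and with other buckets' allreduces.
    with torch.cuda.stream(streams[bucket_idx]):
        dist.all_reduce(flat_grad, group=groups[bucket_idx])

def backward_epilogue():
    # Before the optimizer step, the default stream must wait for every
    # in-flight allreduce, mirroring overlapping_backward_epilogue above.
    for stream, event in zip(streams, events):
        stream.record_event(event)
        torch.cuda.current_stream().wait_event(event)
```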