Commit 8521bb22 authored by Michael Carilli

Patching in changes to enable multiple allreduces in flight

parent 61b8a0fd
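As a quick orientation, the two new constructor options this commit introduces (allreduce_different_streams and prof) might be used as in the sketch below. This is a minimal sketch only: it assumes the class is apex.parallel.DistributedDataParallel, that the job is launched with one process per GPU, and that the model and sizes are placeholders rather than anything from the commit.

import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as ApexDDP  # assumed import path

# Assumes the usual one-process-per-GPU launch with RANK/WORLD_SIZE env vars set.
dist.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(4096, 4096).cuda()  # placeholder model

# allreduce_different_streams=True gives each gradient bucket its own CUDA stream
# and process group so several allreduces can be in flight at once; the commit
# requires delay_allreduce=False in this mode. prof=True wraps the DDP hooks in
# NVTX ranges (torch.cuda.nvtx) so they show up in a profiler timeline.
model = ApexDDP(model,
                delay_allreduce=False,
                allreduce_different_streams=True,
                prof=True)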
@@ -169,9 +169,11 @@ class DistributedDataParallel(Module):
                  allreduce_trigger_params=None,
                  retain_allreduce_buffers=False,
                  allreduce_always_fp32=False,
+                 allreduce_different_streams=False,
                  gradient_average=True,
                  gradient_predivide_factor=1.0,
-                 gradient_average_split_factor=None):
+                 gradient_average_split_factor=None,
+                 prof=False):
         super(DistributedDataParallel, self).__init__()

         # Backward/forward compatibility around
@@ -189,6 +191,13 @@ class DistributedDataParallel(Module):
         self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False

+        self.prof = prof
+
+        if allreduce_different_streams and delay_allreduce:
+            raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
+
+        self.allreduce_different_streams = allreduce_different_streams
+
         if shared_param is not None:
             raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
@@ -213,8 +222,8 @@ class DistributedDataParallel(Module):
         self.delay_allreduce = delay_allreduce
         self.message_size = message_size

-        self.reduction_stream = torch.cuda.Stream()
-        self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
+        self.bucket_streams = []
+        self.bucket_events = []

         self.module = module
@@ -241,15 +250,21 @@ class DistributedDataParallel(Module):
     def __setstate__(self, state):
         super(DistributedDataParallel, self).__setstate__(state)
-        self.reduction_stream = torch.cuda.Stream()
-        self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
+        if allreduce_different_streams and delay_allreduce:
+            raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
+
         if self.delay_allreduce:
             self.needs_refresh = True

+        self.bucket_streams = []
+        self.bucket_events = []
+
     def __getstate__(self):
         attrs = copy.copy(self.__dict__)
         if self._backend != self.backend_enum_holder.NCCL:
-            del attrs['self.reduction_stream']
-            del attrs['self.reduction_event']
+            del attrs['self.bucket_streams']
+            del attrs['self.bucket_events']
         return attrs

     # Broadcast rank 0's bucket structure across all processes, and have all processes
@@ -307,8 +322,9 @@ class DistributedDataParallel(Module):
         def overlapping_backward_epilogue():
-            self.reduction_stream.record_event(self.reduction_event)
-            torch.cuda.current_stream().wait_event(self.reduction_event)
+            for stream, event in zip(self.bucket_streams, self.bucket_events):
+                stream.record_event(event)
+                torch.cuda.current_stream().wait_event(event)

             # Sanity checks that all the buckets were kicked off
             if self.next_bucket != self.num_buckets:
@@ -329,6 +345,9 @@ class DistributedDataParallel(Module):
                     grad_acc = param_tmp.grad_fn.next_functions[0][0]

                     def allreduce_hook(*unused):
+                        if self.prof:
+                            torch.cuda.nvtx.range_push("allreduce_hook")
+
                         if self.delay_allreduce or self.needs_refresh:
                             # TODO: How do we want to handle multiple backward passes between
                             # each forward, e.g., backward passes with retain_graph=True?
@@ -368,11 +387,30 @@ class DistributedDataParallel(Module):
                             self.comm_ready_buckets(param)

+                        if self.prof:
+                            torch.cuda.nvtx.range_pop()
+
                     grad_acc.register_hook(allreduce_hook)
                     self.grad_accs.append(grad_acc)

                 wrapper(param)

+    def _stream_this_bucket(self, bucket_idx):
+        if self.allreduce_different_streams:
+            return self.bucket_streams[bucket_idx]
+        else:
+            return self.bucket_streams[0]
+
+    def _event_this_bucket(self, bucket_idx):
+        if self.allreduce_different_streams:
+            return self.bucket_events[bucket_idx]
+        else:
+            return self.bucket_events[0]
+
     def allreduce_bucket(self, bucket):
         tensor = flatten(bucket)
@@ -384,6 +422,9 @@ class DistributedDataParallel(Module):
         if self.gradient_predivide_factor != 1.0:
             tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)

-        dist.all_reduce(tensor_to_allreduce)
+        if self.allreduce_different_streams and self.bucket_pgs:
+            dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
+        else:
+            dist.all_reduce(tensor_to_allreduce)

         if self.gradient_average:
@@ -396,7 +437,7 @@ class DistributedDataParallel(Module):
     def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
-        allreduced = self.allreduce_bucket(bucket)
+        allreduced = self.allreduce_bucket(bucket, bucket_idx)
         if self.retain_allreduce_buffers:
             if self.allreduce_buffers[bucket_idx] is not None:
                 raise RuntimeError("The backward pass is attempting to replace an already-filled "
@@ -432,6 +473,8 @@ class DistributedDataParallel(Module):
     def comm_ready_buckets(self, param):
         # Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
         # self.reduction_stream.wait_stream(torch.cuda.current_stream())
+        if self.prof:
+            torch.cuda.nvtx.range_push("comm_ready_buckets")

         bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]
@@ -444,9 +487,11 @@ class DistributedDataParallel(Module):
         if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
             if bucket_idx == self.next_bucket:
-                torch.cuda.current_stream().record_event(self.reduction_event)
-                self.reduction_stream.wait_event(self.reduction_event)
-                with torch.cuda.stream(self.reduction_stream):
+                bucket_stream = self._stream_this_bucket(bucket_idx)
+                bucket_event = self._event_this_bucket(bucket_idx)
+                torch.cuda.current_stream().record_event(bucket_event)
+                bucket_stream.wait_event(bucket_event)
+                with torch.cuda.stream(bucket_stream):
                     self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)

                 self.next_bucket += 1
@@ -468,10 +513,16 @@ class DistributedDataParallel(Module):
             else:
                 self.ready_buckets_not_reduced.add(bucket_idx)

+        if self.prof:
+            torch.cuda.nvtx.range_pop()
+
     def forward(self, *inputs, **kwargs):
         result = self.module(*inputs, **kwargs)

+        if self.prof:
+            torch.cuda.nvtx.range_push("forward pass DDP logic")
+
         if not self.delay_allreduce:
             param_list = [param for param in self.module.parameters() if param.requires_grad]
@@ -492,9 +543,40 @@ class DistributedDataParallel(Module):
                 self.bucket_sizes = []
                 self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
                 self.param_id_to_bucket = {}
+                self.bucket_pgs = []
+                self.bucket_streams = []
+                self.bucket_events = []
             else:
-                self.buckets = [[None for _ in range(self.bucket_sizes[i])]
-                                for i in range(self.num_buckets)]
+                if not self.buckets:
+                    self.buckets = [[None for _ in range(self.bucket_sizes[i])]
+                                    for i in range(self.num_buckets)]
+                else:
+                    assert len(self.buckets) == self.num_buckets, "len(buckets) = {}, expected {}".format(
+                        len(self.buckets), self.num_buckets)
+                    for b, bucket in enumerate(self.buckets):
+                        assert len(bucket) == self.bucket_sizes[b], "len(buckets[{}]) = {}, expected {})".format(
+                            b, len(buckets[b]), self.bucket_sizes[b])
+                        for i in range(len(bucket)):
+                            bucket[i] = None
+
+            if self.allreduce_different_streams:
+                if not self.bucket_pgs:
+                    self.bucket_pgs = [dist.new_group() for _ in range(self.num_buckets)]
+                    for i, bg in enumerate(self.bucket_pgs):
+                        print("rank {} created group {} with backend {}".format(
+                              dist.get_rank(), i, dist.get_backend(bg)))
+
+            if self.allreduce_different_streams:
+                if not self.bucket_streams:
+                    self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_buckets)]
+                    self.bucket_events = [torch.cuda.Event(enable_timing=False,
+                                                           blocking=False) for _ in range(self.num_buckets)]
+            else:
+                if not self.bucket_streams:
+                    self.bucket_streams = [torch.cuda.Stream()]
+                    self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]

             self.buckets_ready_size = [0 for i in range(self.num_buckets)]
             if(self.retain_allreduce_buffers):
                 self.allreduce_buffers = [None for _ in range(self.num_buckets)]
@@ -505,4 +587,7 @@ class DistributedDataParallel(Module):
         self.callback_queued = False

+        if self.prof:
+            torch.cuda.nvtx.range_pop()
+
         return result
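For reference, the stream/event handoff that comm_ready_buckets and overlapping_backward_epilogue implement above boils down to the standalone sketch below. It is illustrative only: the function name overlap_bucket_allreduces and the pre-flattened grad_buckets tensors are hypothetical, and the per-bucket streams, events, and process groups are created inline here, whereas the commit creates and caches them once in forward().

import torch
import torch.distributed as dist

def overlap_bucket_allreduces(grad_buckets):
    # One stream, event, and process group per bucket lets the allreduces run
    # concurrently instead of serializing on a single reduction stream.
    streams = [torch.cuda.Stream() for _ in grad_buckets]
    events = [torch.cuda.Event(enable_timing=False, blocking=False) for _ in grad_buckets]
    groups = [dist.new_group() for _ in grad_buckets]

    for bucket, stream, event, group in zip(grad_buckets, streams, events, groups):
        # The bucket's stream must not start until the gradients have been
        # produced on the current (compute) stream.
        torch.cuda.current_stream().record_event(event)
        stream.wait_event(event)
        with torch.cuda.stream(stream):
            dist.all_reduce(bucket, group=group)

    # Epilogue: the compute stream must not touch the reduced gradients until
    # every bucket's allreduce has finished on its own stream.
    for stream, event in zip(streams, events):
        stream.record_event(event)
        torch.cuda.current_stream().wait_event(event)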