Commit 25ac9897 authored by Michael Carilli

Moving flat allreduce buffer creation to main stream

parent b8965a78
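
For context: after this change, flatten(bucket) runs on whichever stream is current (the main stream), and the allreduce then either stays on self.main_stream (force_default_stream=True) or hops onto a per-bucket side stream via an event handoff. A minimal, self-contained sketch of that record_event / wait_event pattern, with made-up sizes and stream names rather than apex's real buckets (assumes a CUDA-capable PyTorch build):

import torch

assert torch.cuda.is_available()
main_stream = torch.cuda.current_stream()
side_stream = torch.cuda.Stream()
ready = torch.cuda.Event()

# Produce a flat buffer on the main stream (stand-in for flatten(bucket)).
flat = torch.randn(1 << 20, device="cuda")

# Publish the main stream's work so far, then make the side stream wait on it
# before touching 'flat' (mirrors record_event / wait_event in allreduce_bucket).
main_stream.record_event(ready)
side_stream.wait_event(ready)

with torch.cuda.stream(side_stream):
    flat.mul_(0.5)                   # stand-in for the allreduce + averaging
    flat.record_stream(side_stream)  # mark 'flat' as in use on side_stream

# Any later consumer on the main stream must wait for the side stream's work.
main_stream.wait_stream(side_stream)
print(flat[:4])

The record_stream call is what lets a buffer created on one stream be consumed on another without the caching allocator reclaiming its memory too early.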
@@ -222,6 +222,8 @@ class DistributedDataParallel(Module):
         self.delay_allreduce = delay_allreduce
         self.message_size = message_size
 
+        self.main_stream = torch.cuda.current_stream()
+
         self.bucket_streams = []
         self.bucket_events = []
@@ -411,33 +413,64 @@ class DistributedDataParallel(Module):
             return self.bucket_events[0]
 
-    def allreduce_bucket(self, bucket, bucket_idx):
+    def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):
         tensor = flatten(bucket)
 
-        tensor_to_allreduce = tensor
-
-        if self.allreduce_always_fp32:
-            tensor_to_allreduce = tensor.float()
-
-        if self.gradient_predivide_factor != 1.0:
-            tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
-
-        if self.allreduce_different_streams and self.bucket_pgs:
-            dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
-        else:
-            dist.all_reduce(tensor_to_allreduce)
-
-        if self.gradient_average:
-            tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
-
-        if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
-            tensor.copy_(tensor_to_allreduce)
+        if force_default_stream:
+            bucket_stream = self.main_stream
+        else:
+            bucket_stream = self._stream_this_bucket(bucket_idx)
+            bucket_event = self._event_this_bucket(bucket_idx)
+            torch.cuda.current_stream().record_event(bucket_event)
+            bucket_stream.wait_event(bucket_event)
+
+        with torch.cuda.stream(bucket_stream):
+            # self.main_stream.wait_stream(torch.cuda.current_stream())
+            # torch.cuda.synchronize()
+
+            tensor_to_allreduce = tensor
+
+            if self.allreduce_always_fp32:
+                tensor_to_allreduce = tensor.float()
+
+            if self.gradient_predivide_factor != 1.0:
+                tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
+
+            if self.allreduce_different_streams and self.bucket_pgs:
+                dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
+            else:
+                dist.all_reduce(tensor_to_allreduce)
+
+            if self.gradient_average:
+                tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
+
+            if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
+                tensor.copy_(tensor_to_allreduce)
+
+            if not self.retain_allreduce_buffers:
+                if multi_tensor_applier.available:
+                    multi_tensor_applier(
+                        self.multi_tensor_scale,
+                        self._overflow_buf,
+                        [unflatten(tensor, bucket), bucket],
+                        1.0)
+                else:
+                    for buf, synced in zip(bucket, unflatten(tensor, bucket)):
+                        buf.copy_(synced)
+
+            # Any subsequent operations that we do on tensor after allreduce_bucket returns must
+            # be synced on bucket_stream anyway.
+            # Also, we maintain a live reference to the returned tensor in allreduce_buffers.
+            # But this doesn't hurt.
+            tensor.record_stream(bucket_stream)
+
+        # torch.cuda.synchronize()
 
         return tensor
 
-    def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
-        allreduced = self.allreduce_bucket(bucket, bucket_idx)
+    def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):
+        allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)
         if self.retain_allreduce_buffers:
             if self.allreduce_buffers[bucket_idx] is not None:
                 raise RuntimeError("The backward pass is attempting to replace an already-filled "
@@ -445,19 +478,15 @@ class DistributedDataParallel(Module):
             self.allreduce_buffers[bucket_idx] = allreduced
             for view, grad in zip(unflatten(allreduced, bucket), bucket):
                 grad.data = view
-        else:
-            if multi_tensor_applier.available:
-                multi_tensor_applier(
-                    self.multi_tensor_scale,
-                    self._overflow_buf,
-                    [unflatten(allreduced, bucket), bucket],
-                    1.0)
-            else:
-                for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
-                    buf.copy_(synced)
+        # for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
+        #     buf.copy_(synced)
 
     def allreduce_fallback(self):
+        for stream, event in zip(self.bucket_streams, self.bucket_events):
+            stream.record_event(event)
+            torch.cuda.current_stream().wait_event(event)
+
         if self.retain_allreduce_buffers:
             grads = [param.grad for param in self.module.parameters() if param.grad is not None]
         else:
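
The loop added at the top of allreduce_fallback above drains the per-bucket side streams back into the current (main) stream before the flat buffers are rebuilt there. A hedged stand-alone sketch of that drain pattern, using hypothetical local lists rather than self.bucket_streams / self.bucket_events (requires a CUDA device):

import torch

assert torch.cuda.is_available()
bucket_streams = [torch.cuda.Stream() for _ in range(4)]
bucket_events = [torch.cuda.Event() for _ in range(4)]

# (Per-bucket work would have been enqueued on each side stream earlier.)

# Record an event at the tail of each side stream and make the current stream
# wait on it, so all previously enqueued side-stream work is ordered before
# anything launched on the main stream afterwards.
for stream, event in zip(bucket_streams, bucket_events):
    stream.record_event(event)
    torch.cuda.current_stream().wait_event(event)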
@@ -472,7 +501,7 @@ class DistributedDataParallel(Module):
         self.allreduce_buffers = [None for _ in range(len(split_buckets))]
 
         for i, bucket in enumerate(split_buckets):
-            allreduced = self.allreduce_maybe_retain(bucket, i)
+            allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True)
 
     def comm_ready_buckets(self, param):
@@ -496,29 +525,24 @@ class DistributedDataParallel(Module):
         if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
             if bucket_idx == self.next_bucket:
-                bucket_stream = self._stream_this_bucket(bucket_idx)
-                bucket_event = self._event_this_bucket(bucket_idx)
-                torch.cuda.current_stream().record_event(bucket_event)
-                bucket_stream.wait_event(bucket_event)
-                with torch.cuda.stream(bucket_stream):
-                    self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
+                self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
 
                 self.next_bucket += 1
 
                 # Reversing upstream's logic here, because we constructed our buckets based on
                 # the order things were received during backward.
                 if len(self.ready_buckets_not_reduced) > 0:
                     sorted_todo = sorted(self.ready_buckets_not_reduced)
                     for i in sorted_todo:
                         # Nothing can be reduced now
                         if i > self.next_bucket:
                             break
                         elif i == self.next_bucket:
                             self.allreduce_maybe_retain(self.buckets[i], i)
                             self.ready_buckets_not_reduced.remove(i)
                             self.next_bucket += 1
                         else:
                             raise ValueError("i should always be >= next_bucket")
             else:
                 self.ready_buckets_not_reduced.add(bucket_idx)
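
The in-order drain of ready_buckets_not_reduced above is untouched by this commit, but the bookkeeping is easier to see in isolation: buckets must be reduced strictly in index order, so out-of-order arrivals are parked in a set and flushed once their turn comes. A toy, GPU-free sketch with illustrative names (not apex code):

next_bucket = 0
ready_not_reduced = set()

def reduce_bucket(idx):
    print(f"allreduce bucket {idx}")

def bucket_ready(idx):
    """Reduce bucket idx if it is next in line; otherwise park it for later."""
    global next_bucket
    if idx == next_bucket:
        reduce_bucket(idx)
        next_bucket += 1
        # Flush any parked buckets that are now contiguous with next_bucket.
        for i in sorted(ready_not_reduced):
            if i > next_bucket:
                break
            elif i == next_bucket:
                reduce_bucket(i)
                ready_not_reduced.remove(i)
                next_bucket += 1
    else:
        ready_not_reduced.add(idx)

# Gradients finish out of order during backward:
for idx in (0, 2, 1, 3):
    bucket_ready(idx)   # prints buckets 0, 1, 2, 3 in order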