Commit 25ac9897 authored by Michael Carilli

Moving flat allreduce buffer creation to main stream

parent b8965a78
@@ -222,6 +222,8 @@ class DistributedDataParallel(Module):
         self.delay_allreduce = delay_allreduce
         self.message_size = message_size

+        self.main_stream = torch.cuda.current_stream()
+
         self.bucket_streams = []
         self.bucket_events = []
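For context on the first hunk: `self.main_stream` captures whichever stream is current when the module is constructed, while the pre-existing `self.bucket_streams` / `self.bucket_events` pools back the `_stream_this_bucket` / `_event_this_bucket` helpers used below. Those helpers are not shown in this diff; the snippet below is only a rough sketch of how such a lazily grown per-bucket stream/event pool could look, not apex's actual implementation.

```python
import torch

class BucketStreamPool:
    """Illustrative only: one side stream and one event per bucket index,
    created on first use, plus the 'main' stream captured at construction."""
    def __init__(self):
        self.main_stream = torch.cuda.current_stream()
        self.bucket_streams = []
        self.bucket_events = []

    def stream_this_bucket(self, bucket_idx):
        # Grow the pools so that bucket i always maps to the same stream/event.
        while len(self.bucket_streams) <= bucket_idx:
            self.bucket_streams.append(torch.cuda.Stream())
            self.bucket_events.append(torch.cuda.Event())
        return self.bucket_streams[bucket_idx]

    def event_this_bucket(self, bucket_idx):
        self.stream_this_bucket(bucket_idx)  # make sure the pools are long enough
        return self.bucket_events[bucket_idx]
```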
@@ -411,9 +413,21 @@
         return self.bucket_events[0]

-    def allreduce_bucket(self, bucket, bucket_idx):
+    def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):
         tensor = flatten(bucket)

+        if force_default_stream:
+            bucket_stream = self.main_stream
+        else:
+            bucket_stream = self._stream_this_bucket(bucket_idx)
+            bucket_event = self._event_this_bucket(bucket_idx)
+            torch.cuda.current_stream().record_event(bucket_event)
+            bucket_stream.wait_event(bucket_event)
+
+        with torch.cuda.stream(bucket_stream):
+            # self.main_stream.wait_stream(torch.cuda.current_stream())
+            # torch.cuda.synchronize()
+
             tensor_to_allreduce = tensor

             if self.allreduce_always_fp32:
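The branch added here implements the usual CUDA stream handoff: the gradients in `bucket` are produced on the stream that is current during backward, so before the bucket's allreduce stream may touch the flattened tensor it waits (on the GPU, not the host) for an event recorded on the producer stream. A self-contained sketch of that pattern, with a multiply standing in for the allreduce:

```python
import torch

def run_on_side_stream(tensor, side_stream):
    """Sketch of the record_event / wait_event handoff used above."""
    ready = torch.cuda.Event()
    torch.cuda.current_stream().record_event(ready)  # producer progress marker
    side_stream.wait_event(ready)                    # consumer waits on the GPU timeline

    with torch.cuda.stream(side_stream):
        result = tensor * 2  # stand-in for dist.all_reduce(tensor)
        # `tensor` was allocated on the producer stream; record_stream tells the
        # caching allocator it is also in use here, so its memory is not recycled early.
        tensor.record_stream(side_stream)
    return result

side = torch.cuda.Stream()
grad = torch.randn(1024, device="cuda")
out = run_on_side_stream(grad, side)
torch.cuda.current_stream().wait_stream(side)  # re-sync before the default stream uses `out`
print(out.sum().item())
```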
@@ -433,11 +447,30 @@
             if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
                 tensor.copy_(tensor_to_allreduce)

+            if not self.retain_allreduce_buffers:
+                if multi_tensor_applier.available:
+                    multi_tensor_applier(
+                        self.multi_tensor_scale,
+                        self._overflow_buf,
+                        [unflatten(tensor, bucket), bucket],
+                        1.0)
+                else:
+                    for buf, synced in zip(bucket, unflatten(tensor, bucket)):
+                        buf.copy_(synced)
+
+            # Any subsequent operations that we do on tensor after allreduce_bucket returns must
+            # be synced on bucket_stream anyway.
+            # Also, we maintain a live reference to the returned tensor in allreduce_buffers.
+            # But this doesn't hurt.
+            tensor.record_stream(bucket_stream)
+
+        # torch.cuda.synchronize()
+
         return tensor

-    def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
-        allreduced = self.allreduce_bucket(bucket, bucket_idx)
+    def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):
+        allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)
         if self.retain_allreduce_buffers:
             if self.allreduce_buffers[bucket_idx] is not None:
                 raise RuntimeError("The backward pass is attempting to replace an already-filled "
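Behaviorally, the block moved into `allreduce_bucket` scatters the flat, allreduced buffer back into the individual gradient tensors; `multi_tensor_scale` fuses those copies into one kernel when the apex extension is built, and the `zip`/`copy_` loop is the pure-PyTorch fallback. A minimal sketch of the same copy-back using only the `torch._utils` flatten/unflatten helpers (the collective itself is omitted):

```python
import torch
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten

def flat_allreduce_copy_back(bucket):
    """Flatten a bucket of gradients, (pretend to) allreduce the flat buffer,
    then write the synced values back into the original tensors."""
    tensor = flatten(bucket)  # one contiguous buffer covering the whole bucket
    # dist.all_reduce(tensor) would run here, on the bucket's stream.
    for buf, synced in zip(bucket, unflatten(tensor, bucket)):
        buf.copy_(synced)     # same result the fused multi_tensor_scale path produces
    return tensor

grads = [torch.randn(3, 3, device="cuda") for _ in range(4)]
flat = flat_allreduce_copy_back(grads)
```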
@@ -445,19 +478,15 @@
             self.allreduce_buffers[bucket_idx] = allreduced
             for view, grad in zip(unflatten(allreduced, bucket), bucket):
                 grad.data = view
-        else:
-            if multi_tensor_applier.available:
-                multi_tensor_applier(
-                    self.multi_tensor_scale,
-                    self._overflow_buf,
-                    [unflatten(allreduced, bucket), bucket],
-                    1.0)
-            else:
-                for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
-                    buf.copy_(synced)
+        # for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
+        #     buf.copy_(synced)

     def allreduce_fallback(self):
+        for stream, event in zip(self.bucket_streams, self.bucket_events):
+            stream.record_event(event)
+            torch.cuda.current_stream().wait_event(event)
+
         if self.retain_allreduce_buffers:
             grads = [param.grad for param in self.module.parameters() if param.grad is not None]
         else:
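The loop added at the top of `allreduce_fallback` is the reverse handoff: before the flat fallback allreduce is enqueued on the current (main) stream, the current stream is made to wait for everything already queued on each per-bucket stream. A standalone sketch of that drain, assuming a list of side streams (fresh events are used here; the diff reuses `self.bucket_events`):

```python
import torch

def drain_side_streams(side_streams):
    """Order all future work on the current stream after the work already
    queued on each side stream, without blocking the host."""
    current = torch.cuda.current_stream()
    for stream in side_streams:
        done = torch.cuda.Event()
        stream.record_event(done)  # marks the end of work queued so far on `stream`
        current.wait_event(done)   # current stream waits on the GPU timeline

side_streams = [torch.cuda.Stream() for _ in range(4)]
drain_side_streams(side_streams)
# Anything launched on the current stream from here on (e.g. one big flat
# allreduce) runs after every per-bucket allreduce has finished.
```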
@@ -472,7 +501,7 @@
         self.allreduce_buffers = [None for _ in range(len(split_buckets))]

         for i, bucket in enumerate(split_buckets):
-            allreduced = self.allreduce_maybe_retain(bucket, i)
+            allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True)
def comm_ready_buckets(self, param):
@@ -496,11 +525,6 @@
         if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
             if bucket_idx == self.next_bucket:
-                bucket_stream = self._stream_this_bucket(bucket_idx)
-                bucket_event = self._event_this_bucket(bucket_idx)
-                torch.cuda.current_stream().record_event(bucket_event)
-                bucket_stream.wait_event(bucket_event)
-                with torch.cuda.stream(bucket_stream):
                 self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)

                 self.next_bucket += 1
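None of this changes the user-facing API; the stream handling stays internal to the wrapper. For reference, a hedged usage sketch of `apex.parallel.DistributedDataParallel`, assuming a one-process-per-GPU launch where the launcher sets the usual distributed environment variables:

```python
import os
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")  # assumes MASTER_ADDR/PORT etc. come from the launcher

model = torch.nn.Linear(1024, 1024).cuda()
# Bucketed, overlapped allreduces are the default path; delay_allreduce=True
# would instead take the flat fallback path at the end of backward.
model = DDP(model, message_size=10000000, delay_allreduce=False)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inp = torch.randn(32, 1024, device="cuda")
loss = model(inp).sum()
loss.backward()   # gradients are allreduced bucket by bucket on side streams
optimizer.step()
```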