Commit 71957836 authored by Michael Carilli

Option to preinitialize allreduce communicators

parent fedfe0d7
@@ -169,8 +169,8 @@ class DistributedDataParallel(Module):
                  allreduce_trigger_params=None,
                  retain_allreduce_buffers=False,
                  allreduce_always_fp32=False,
-                 allreduce_different_streams=False,
                  num_allreduce_streams=1,
+                 allreduce_communicators=None,
                  gradient_average=True,
                  gradient_predivide_factor=1.0,
                  gradient_average_split_factor=None,
@@ -197,8 +197,13 @@ class DistributedDataParallel(Module):
         if allreduce_different_streams and delay_allreduce:
             raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
 
-        self.allreduce_different_streams = allreduce_different_streams
+        self.allreduce_different_streams = (num_allreduce_streams > 1)
         self.num_allreduce_streams = num_allreduce_streams
+        self.allreduce_communicators = allreduce_communicators
+        if self.allreduce_communicators:
+            assert len(allreduce_communicators[0]) == num_allreduce_streams
+            assert len(allreduce_communicators[0]) == len(allreduce_communicators[1])
+            assert allreduce_different_streams
 
         if shared_param is not None:
             raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
@@ -594,22 +599,28 @@ class DistributedDataParallel(Module):
                             b, len(buckets[b]), self.bucket_sizes[b])
                 for i in range(len(bucket)):
                     bucket[i] = None
 
-            if self.allreduce_different_streams:
-                if not self.bucket_pgs:
-                    self.bucket_pgs = [dist.new_group() for _ in range(self.num_allreduce_streams)]
-                    for i, bg in enumerate(self.bucket_pgs):
-                        print("rank {} created group {} with backend {}".format(
-                              dist.get_rank(), i, dist.get_backend(bg)))
-            if self.allreduce_different_streams:
-                if not self.bucket_streams:
-                    self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_allreduce_streams)]
-                    self.bucket_events = [torch.cuda.Event(enable_timing=False,
-                                                           blocking=False) for _ in range(self.num_allreduce_streams)]
-            else:
-                if not self.bucket_streams:
-                    self.bucket_streams = [torch.cuda.Stream()]
-                    self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]
+            if self.allreduce_communicators:
+                self.bucket_pgs = self.allreduce_communicators[0]
+                self.bucket_streams = self.allreduce_communicators[1]
+                self.bucket_events = [torch.cuda.Event(enable_timing=False,
+                                                       blocking=False) for _ in range(self.num_allreduce_streams)]
+            else:
+                if self.allreduce_different_streams:
+                    if not self.bucket_pgs:
+                        self.bucket_pgs = [dist.new_group() for _ in range(self.num_allreduce_streams)]
+                        for i, bg in enumerate(self.bucket_pgs):
+                            print("rank {} created group {} with backend {}".format(
+                                  dist.get_rank(), i, dist.get_backend(bg)))
+                if self.allreduce_different_streams:
+                    if not self.bucket_streams:
+                        self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_allreduce_streams)]
+                        self.bucket_events = [torch.cuda.Event(enable_timing=False,
+                                                               blocking=False) for _ in range(self.num_allreduce_streams)]
+                else:
+                    if not self.bucket_streams:
+                        self.bucket_streams = [torch.cuda.Stream()]
+                        self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]
 
             self.buckets_ready_size = [0 for i in range(self.num_buckets)]
             if(self.retain_allreduce_buffers):
......
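
For context, a hedged end-to-end sketch of how the new option might be used from the caller's side: build the process groups and CUDA streams up front and hand them to apex.parallel.DistributedDataParallel, instead of letting the module create them lazily during the backward pass. The process-group setup, device selection, and toy model below are illustrative assumptions, not part of this commit:

import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

# Assumed single-node launch (e.g. via torch.distributed.launch), one process per GPU.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

num_allreduce_streams = 2  # illustrative choice

# Preinitialized communicators: one process group and one CUDA stream per allreduce stream.
pgs = [dist.new_group() for _ in range(num_allreduce_streams)]
streams = [torch.cuda.Stream() for _ in range(num_allreduce_streams)]

model = torch.nn.Linear(1024, 1024).cuda()  # toy model, for illustration only
model = DDP(model,
            num_allreduce_streams=num_allreduce_streams,
            allreduce_communicators=(pgs, streams))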