Commit 73d4212d authored by Michael Carilli

Explicit control over number of allreduce streams for DDP

parent 070c7e96
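
For context, a minimal usage sketch of the option this commit adds. The model construction and launch details are illustrative placeholders, not part of the patch; the DDP keyword arguments are the ones visible in the diff below.

    import torch
    import torch.distributed as dist
    from apex.parallel import DistributedDataParallel as DDP

    # Assumes one process per GPU, e.g. launched via torch.distributed.launch.
    dist.init_process_group(backend="nccl")

    model = build_model().cuda()  # build_model() is a placeholder

    # Cycle gradient buckets over two allreduce streams; the matching
    # streams, events, and process groups are created lazily inside apex DDP.
    model = DDP(model,
                allreduce_different_streams=True,
                num_allreduce_streams=2)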
@@ -170,6 +170,7 @@ class DistributedDataParallel(Module):
                  retain_allreduce_buffers=False,
                  allreduce_always_fp32=False,
                  allreduce_different_streams=False,
+                 num_allreduce_streams=1,
                  gradient_average=True,
                  gradient_predivide_factor=1.0,
                  gradient_average_split_factor=None,
@@ -197,6 +198,7 @@ class DistributedDataParallel(Module):
             raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
 
         self.allreduce_different_streams = allreduce_different_streams
+        self.num_allreduce_streams = num_allreduce_streams
 
         if shared_param is not None:
             raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
@@ -401,14 +403,14 @@ class DistributedDataParallel(Module):
     def _stream_this_bucket(self, bucket_idx):
         if self.allreduce_different_streams:
-            return self.bucket_streams[bucket_idx]
+            return self.bucket_streams[bucket_idx%self.num_allreduce_streams]
         else:
             return self.bucket_streams[0]
 
     def _event_this_bucket(self, bucket_idx):
         if self.allreduce_different_streams:
-            return self.bucket_events[bucket_idx]
+            return self.bucket_events[bucket_idx%self.num_allreduce_streams]
         else:
             return self.bucket_events[0]
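
The modulo indexing above is the heart of the change: streams and events are now drawn round-robin from a fixed pool of num_allreduce_streams entries, instead of one per bucket. A standalone sketch of the same pattern, with illustrative names that are not apex's:

    import torch

    num_allreduce_streams = 4  # illustrative pool size
    streams = [torch.cuda.Stream() for _ in range(num_allreduce_streams)]

    def stream_this_bucket(bucket_idx):
        # Round-robin: bucket i reuses stream i % pool size, so the pool
        # no longer has to grow with the number of gradient buckets.
        return streams[bucket_idx % num_allreduce_streams]

    # With a pool of 4, bucket 0 and bucket 4 share a stream.
    assert stream_this_bucket(0) is stream_this_bucket(4)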
@@ -436,8 +438,8 @@ class DistributedDataParallel(Module):
         if self.gradient_predivide_factor != 1.0:
             tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
 
-        if self.allreduce_different_streams and self.bucket_pgs:
-            dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
+        if self.allreduce_different_streams and not force_default_stream:
+            dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx%self.num_allreduce_streams])
         else:
             dist.all_reduce(tensor_to_allreduce)
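
Each stream is paired with its own process group because collectives issued on a single NCCL communicator execute in issue order; overlapping allreduces need separate communicators. A hedged sketch of that pairing follows; the helper name and module-level pools are my own, and force_default_stream is presumably a flag threaded in elsewhere in this commit (its plumbing isn't visible in these hunks):

    import torch
    import torch.distributed as dist

    # Illustrative standalone pools; in the patch these live on the DDP
    # instance. Assumes dist.init_process_group() has already run.
    num_allreduce_streams = 2
    bucket_pgs = [dist.new_group() for _ in range(num_allreduce_streams)]
    bucket_streams = [torch.cuda.Stream() for _ in range(num_allreduce_streams)]

    def allreduce_bucket(flat_grads, bucket_idx, force_default_stream=False):
        if force_default_stream:
            # Fall back to the default group on the caller's current stream.
            dist.all_reduce(flat_grads)
        else:
            idx = bucket_idx % num_allreduce_streams
            with torch.cuda.stream(bucket_streams[idx]):
                # One communicator per stream keeps concurrent allreduces
                # from serializing against each other.
                dist.all_reduce(flat_grads, group=bucket_pgs[idx])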
@@ -579,8 +581,8 @@ class DistributedDataParallel(Module):
                 self.bucket_streams = []
                 self.bucket_events = []
         else:
-            self.buckets = [[None for _ in range(self.bucket_sizes[i])]
-                            for i in range(self.num_buckets)]
+            # self.buckets = [[None for _ in range(self.bucket_sizes[i])]
+            #                 for i in range(self.num_buckets)]
             if not self.buckets:
                 self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                                 for i in range(self.num_buckets)]
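
This hunk stops rebuilding the bucket grid on every pass; it is now allocated only when empty, so bucket bookkeeping persists across iterations. An isolated sketch of the allocate-if-empty pattern, with made-up sizes:

    # Build the bucket grid once and reuse it, instead of reallocating
    # each iteration (counts below are illustrative, not apex's values).
    num_buckets, bucket_sizes = 2, [3, 5]

    buckets = []
    for _ in range(2):  # two "iterations"
        if not buckets:
            buckets = [[None for _ in range(bucket_sizes[i])]
                       for i in range(num_buckets)]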
@@ -594,15 +596,15 @@ class DistributedDataParallel(Module):
                     bucket[i] = None
 
         if self.allreduce_different_streams:
             if not self.bucket_pgs:
-                self.bucket_pgs = [dist.new_group() for _ in range(self.num_buckets)]
+                self.bucket_pgs = [dist.new_group() for _ in range(self.num_allreduce_streams)]
                 for i, bg in enumerate(self.bucket_pgs):
                     print("rank {} created group {} with backend {}".format(
                           dist.get_rank(), i, dist.get_backend(bg)))
 
         if self.allreduce_different_streams:
             if not self.bucket_streams:
-                self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_buckets)]
-                self.bucket_events = [torch.cuda.Event(enable_timing=False,
-                                                       blocking=False) for _ in range(self.num_buckets)]
+                self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_allreduce_streams)]
+                self.bucket_events = [torch.cuda.Event(enable_timing=False,
+                                                       blocking=False) for _ in range(self.num_allreduce_streams)]
         else:
             if not self.bucket_streams:
                 self.bucket_streams = [torch.cuda.Stream()]
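
The stream and event pools are likewise sized by num_allreduce_streams. A minimal sketch of how such an event typically hands side-stream work back to the default stream; this is a simplified assumed usage, not the literal apex code path:

    import torch

    side_stream = torch.cuda.Stream()
    done = torch.cuda.Event(enable_timing=False, blocking=False)

    grad = torch.ones(1024, device="cuda")
    side_stream.wait_stream(torch.cuda.current_stream())  # grad is ready
    with torch.cuda.stream(side_stream):
        grad.mul_(0.5)        # stand-in for the bucket's allreduce
    done.record(side_stream)  # mark the side-stream work as finished

    # The consumer (e.g. the optimizer step) waits on the event instead
    # of synchronizing the whole device.
    torch.cuda.current_stream().wait_event(done)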