Commit 72bce160 authored by Michael Carilli

allreduce_different_streams is now hidden

parent 3b4a0a23
@@ -200,10 +200,10 @@ class DistributedDataParallel(Module):
         if self.allreduce_communicators:
             assert len(allreduce_communicators[0]) == num_allreduce_streams
             assert len(allreduce_communicators[0]) == len(allreduce_communicators[1])
-            assert allreduce_different_streams
+            assert self.allreduce_different_streams
 
         if self.allreduce_different_streams and delay_allreduce:
-            raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
+            raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.")
 
         if shared_param is not None:
             raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
@@ -259,8 +259,8 @@ class DistributedDataParallel(Module):
     def __setstate__(self, state):
         super(DistributedDataParallel, self).__setstate__(state)
 
-        if allreduce_different_streams and delay_allreduce:
-            raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
+        if self.allreduce_different_streams and delay_allreduce:
+            raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.")
 
         if self.delay_allreduce:
             self.needs_refresh = True
@@ -35,8 +35,9 @@ class Model(Module):
 model = Model()
 # model = DDP(model, message_size=1, gradient_predivide_factor=8.0)
-model = DDP(model, delay_allreduce=True)
+# model = DDP(model, delay_allreduce=True)
 # model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])
+model = DDP(model, message_size=1, allreduce_trigger_params=[model.b], num_allreduce_streams=3)
 
 x = torch.cuda.FloatTensor(4096*4096)
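
The updated test constructs apex's DistributedDataParallel with num_allreduce_streams rather than passing allreduce_different_streams directly, which is the point of this commit: the multi-stream flag becomes an internal detail. Below is a minimal, self-contained sketch of that usage pattern (not the repo's actual test file). It assumes apex is installed and the script is launched with one process per GPU, e.g. via torch.distributed.launch, so that init_process_group with init_method="env://" can read RANK and WORLD_SIZE from the environment.

    # Sketch only: exercises apex DDP with multiple allreduce streams,
    # mirroring the constructor call in the diff above.
    import torch
    import torch.distributed as dist
    from torch.nn import Module, Parameter
    from apex.parallel import DistributedDataParallel as DDP

    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    class Model(Module):
        # Two large parameters so gradients arrive in more than one allreduce bucket.
        def __init__(self):
            super(Model, self).__init__()
            self.a = Parameter(torch.cuda.FloatTensor(4096 * 4096).fill_(1.0))
            self.b = Parameter(torch.cuda.FloatTensor(4096 * 4096).fill_(2.0))

        def forward(self, input):
            return (input * self.a) * self.b

    model = Model()
    # allreduce_different_streams is no longer passed by the caller; asking for
    # num_allreduce_streams > 1 enables the multi-stream path internally.
    model = DDP(model, message_size=1, allreduce_trigger_params=[model.b],
                num_allreduce_streams=3)

    x = torch.cuda.FloatTensor(4096 * 4096).fill_(1.0)
    model(x).sum().backward()  # backward hooks launch the allreduces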