"tests/vscode:/vscode.git/clone" did not exist on "ad15947f0ea9b34e15157dfad65b25f3a98e9ac8"
Commit 1c2ba890 authored by Thor Johnsen, committed by mcarilli

Add option to turn on/off allreduce in DDP (useful for gradient accumulation) (#356)

parent 47e3367f
@@ -216,6 +216,8 @@ class DistributedDataParallel(Module):
 
         self.module = module
 
+        self.disable_allreduce = False
+
         if self._backend == self.backend_enum_holder.NCCL:
             for param in self.module.parameters():
                 assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
@@ -250,6 +252,12 @@ class DistributedDataParallel(Module):
         del attrs['self.reduction_event']
         return attrs
 
+    def turn_on_allreduce(self):
+        self.disable_allreduce = False
+
+    def turn_off_allreduce(self):
+        self.disable_allreduce = True
+
     # Broadcast rank 0's bucket structure across all processes, and have all processes
     # regenerate their bucket structures to match.
     def sync_bucket_structure(self):
@@ -327,6 +335,7 @@ class DistributedDataParallel(Module):
                 grad_acc = param_tmp.grad_fn.next_functions[0][0]
 
                 def allreduce_hook(*unused):
+                    if not self.disable_allreduce:
                         if self.delay_allreduce or self.needs_refresh:
                             # TODO: How do we want to handle multiple backward passes between
                             # each forward, e.g., backward passes with retain_graph=True?
@@ -470,6 +479,7 @@ class DistributedDataParallel(Module):
     def forward(self, *inputs, **kwargs):
         result = self.module(*inputs, **kwargs)
 
+        if not self.disable_allreduce:
             if not self.delay_allreduce:
                 param_list = [param for param in self.module.parameters() if param.requires_grad]
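
The commit message cites gradient accumulation as the motivation: with allreduce turned off, backward passes only accumulate local gradients, and the cross-process reduction can be deferred to the last micro-batch of each accumulation window. Below is a minimal, hypothetical sketch of that pattern. Only turn_on_allreduce() and turn_off_allreduce() come from this commit; the toy model, loss, accumulation_steps value, and the assumption of an env:// torch.distributed launch are illustrative placeholders, not part of the change.

# Hypothetical usage sketch (not part of this commit). All names below except
# turn_on_allreduce()/turn_off_allreduce() are illustrative placeholders.
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = DDP(torch.nn.Linear(128, 10).cuda())            # toy model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
accumulation_steps = 4                                   # illustrative value

optimizer.zero_grad()
for step in range(16):                                   # stands in for a real data loader
    inputs = torch.randn(32, 128, device="cuda")
    targets = torch.randint(0, 10, (32,), device="cuda")

    last_micro_batch = (step + 1) % accumulation_steps == 0
    # Toggle before forward(): both forward() and the backward hook check
    # self.disable_allreduce, so accumulation-only steps skip the reduction.
    if last_micro_batch:
        model.turn_on_allreduce()
    else:
        model.turn_off_allreduce()

    loss = criterion(model(inputs), targets) / accumulation_steps
    loss.backward()

    if last_micro_batch:
        optimizer.step()
        optimizer.zero_grad()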