Commit 1c2ba890 authored by Thor Johnsen's avatar Thor Johnsen Committed by mcarilli
Browse files

Add option to turn on/off allreduce in DDP (useful for gradient accumulation) (#356)

parent 47e3367f
......@@ -215,6 +215,8 @@ class DistributedDataParallel(Module):
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.module = module
self.disable_allreduce = False
if self._backend == self.backend_enum_holder.NCCL:
for param in self.module.parameters():
......@@ -249,6 +251,12 @@ class DistributedDataParallel(Module):
del attrs['self.reduction_stream']
del attrs['self.reduction_event']
return attrs
def turn_on_allreduce(self):
self.disable_allreduce = False
def turn_off_allreduce(self):
self.disable_allreduce = True
# Broadcast rank 0's bucket structure across all processes, and have all processes
# regenerate their bucket structures to match.
......@@ -327,44 +335,45 @@ class DistributedDataParallel(Module):
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
if self.delay_allreduce or self.needs_refresh:
# TODO: How do we want to handle multiple backward passes between
# each forward, e.g., backward passes with retain_graph=True?
# needs_refresh and callback_queued are both vulnerable states.
if not self.delay_allreduce and self.needs_refresh:
# Use the backward pass to build the bucket structure on the fly.
active_i = self.param_id_to_active_i[id(param)]
# Float, half, and double tensors are grouped into buckets separately.
current_type = self.param_type_to_tmp_i[param.type()]
if not self.disable_allreduce:
if self.delay_allreduce or self.needs_refresh:
# TODO: How do we want to handle multiple backward passes between
# each forward, e.g., backward passes with retain_graph=True?
# needs_refresh and callback_queued are both vulnerable states.
if not self.delay_allreduce and self.needs_refresh:
# Use the backward pass to build the bucket structure on the fly.
active_i = self.param_id_to_active_i[id(param)]
# Float, half, and double tensors are grouped into buckets separately.
current_type = self.param_type_to_tmp_i[param.type()]
self.tmp_buckets[current_type].append(active_i)
ship_tmp_bucket = False
if self.custom_allreduce_triggers:
if id(param) in self.allreduce_trigger_params:
ship_tmp_bucket = True
else:
self.tmp_numels[current_type] += param.numel()
if self.tmp_numels[current_type] >= self.message_size:
ship_tmp_bucket = True
# To consider: If custom_allreduce_triggers are in use, ship all
# tmp_buckets, not just tmp_buckets[current_type].
if ship_tmp_bucket:
self.active_i_buckets.append(self.tmp_buckets[current_type])
self.tmp_buckets[current_type] = []
self.tmp_numels[current_type] = 0
if not self.callback_queued:
Variable._execution_engine.queue_callback(allreduce_params)
self.callback_queued = True
else:
if not self.callback_queued:
Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
self.callback_queued = True
self.comm_ready_buckets(param)
self.tmp_buckets[current_type].append(active_i)
ship_tmp_bucket = False
if self.custom_allreduce_triggers:
if id(param) in self.allreduce_trigger_params:
ship_tmp_bucket = True
else:
self.tmp_numels[current_type] += param.numel()
if self.tmp_numels[current_type] >= self.message_size:
ship_tmp_bucket = True
# To consider: If custom_allreduce_triggers are in use, ship all
# tmp_buckets, not just tmp_buckets[current_type].
if ship_tmp_bucket:
self.active_i_buckets.append(self.tmp_buckets[current_type])
self.tmp_buckets[current_type] = []
self.tmp_numels[current_type] = 0
if not self.callback_queued:
Variable._execution_engine.queue_callback(allreduce_params)
self.callback_queued = True
else:
if not self.callback_queued:
Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
self.callback_queued = True
self.comm_ready_buckets(param)
grad_acc.register_hook(allreduce_hook)
self.grad_accs.append(grad_acc)
......@@ -422,7 +431,7 @@ class DistributedDataParallel(Module):
# training script, and overwritten in the next forward pass. So it's harmless.
if self.retain_allreduce_buffers:
self.allreduce_buffers = [None for _ in range(len(split_buckets))]
for i, bucket in enumerate(split_buckets):
allreduced = self.allreduce_maybe_retain(bucket, i)
......@@ -469,38 +478,39 @@ class DistributedDataParallel(Module):
def forward(self, *inputs, **kwargs):
result = self.module(*inputs, **kwargs)
if not self.delay_allreduce:
param_list = [param for param in self.module.parameters() if param.requires_grad]
# Conditions under which to refresh self.record
# Forward has the authority to set needs_refresh to True, but only allreduce_params
# in backward has the authority to set needs_refresh to False.
# Parentheses are not necessary for correct order of operations, but make the intent clearer.
if ((not self.active_params) or
(len(param_list) != len(self.active_params)) or
any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
self.needs_refresh = True
if self.needs_refresh:
self.active_i_buckets = []
self.buckets = []
self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
self.tmp_numels = [0, 0, 0]
self.bucket_sizes = []
self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
self.param_id_to_bucket = {}
else:
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
self.buckets_ready_size = [0 for i in range(self.num_buckets)]
if(self.retain_allreduce_buffers):
self.allreduce_buffers = [None for _ in range(self.num_buckets)]
self.next_bucket = 0
self.ready_buckets_not_reduced = set()
if not self.disable_allreduce:
if not self.delay_allreduce:
param_list = [param for param in self.module.parameters() if param.requires_grad]
# Conditions under which to refresh self.record
# Forward has the authority to set needs_refresh to True, but only allreduce_params
# in backward has the authority to set needs_refresh to False.
# Parentheses are not necessary for correct order of operations, but make the intent clearer.
if ((not self.active_params) or
(len(param_list) != len(self.active_params)) or
any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
self.needs_refresh = True
if self.needs_refresh:
self.active_i_buckets = []
self.buckets = []
self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
self.tmp_numels = [0, 0, 0]
self.bucket_sizes = []
self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
self.param_id_to_bucket = {}
else:
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
self.buckets_ready_size = [0 for i in range(self.num_buckets)]
if(self.retain_allreduce_buffers):
self.allreduce_buffers = [None for _ in range(self.num_buckets)]
self.next_bucket = 0
self.ready_buckets_not_reduced = set()
self.active_params = param_list
self.active_params = param_list
self.callback_queued = False
self.callback_queued = False
return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment