"...text-generation-inference.git" did not exist on "17b7b75e652394379931c058a8c2db3a000b4225"
Commit 1c2ba890 authored by Thor Johnsen, committed by mcarilli

Add option to turn on/off allreduce in DDP (useful for gradient accumulation) (#356)

parent 47e3367f
@@ -215,6 +215,8 @@ class DistributedDataParallel(Module):
         self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
 
         self.module = module
 
+        self.disable_allreduce = False
+
         if self._backend == self.backend_enum_holder.NCCL:
             for param in self.module.parameters():
@@ -249,6 +251,12 @@ class DistributedDataParallel(Module):
             del attrs['self.reduction_stream']
             del attrs['self.reduction_event']
         return attrs
 
+    def turn_on_allreduce(self):
+        self.disable_allreduce = False
+
+    def turn_off_allreduce(self):
+        self.disable_allreduce = True
+
     # Broadcast rank 0's bucket structure across all processes, and have all processes
     # regenerate their bucket structures to match.
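
For context, a rough sketch of how these toggles might be used for gradient accumulation, the use case named in the commit message. The toy model, optimizer, and `accumulation_steps` below are hypothetical; only `turn_on_allreduce`/`turn_off_allreduce` come from this commit, and the loop assumes `torch.distributed` has already been initialized for this process:

```python
import torch
import torch.nn as nn
import apex.parallel

# Hypothetical setup: assumes torch.distributed.init_process_group(...) and
# torch.cuda.set_device(...) have already been called for this process.
model = apex.parallel.DistributedDataParallel(nn.Linear(10, 10).cuda())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accumulation_steps = 4
optimizer.zero_grad()
for i in range(16):
    inputs = torch.randn(8, 10, device="cuda")

    # Skip inter-process gradient communication on intermediate micro-batches;
    # re-enable it so the last micro-batch's backward performs the allreduce.
    if (i + 1) % accumulation_steps == 0:
        model.turn_on_allreduce()
    else:
        model.turn_off_allreduce()

    loss = model(inputs).sum() / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```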
@@ -327,44 +335,45 @@ class DistributedDataParallel(Module):
                     grad_acc = param_tmp.grad_fn.next_functions[0][0]
 
                     def allreduce_hook(*unused):
+                        if not self.disable_allreduce:
                             if self.delay_allreduce or self.needs_refresh:
                                 # TODO: How do we want to handle multiple backward passes between
                                 # each forward, e.g., backward passes with retain_graph=True?
                                 # needs_refresh and callback_queued are both vulnerable states.
                                 if not self.delay_allreduce and self.needs_refresh:
                                     # Use the backward pass to build the bucket structure on the fly.
                                     active_i = self.param_id_to_active_i[id(param)]
 
                                     # Float, half, and double tensors are grouped into buckets separately.
                                     current_type = self.param_type_to_tmp_i[param.type()]
 
                                     self.tmp_buckets[current_type].append(active_i)
 
                                     ship_tmp_bucket = False
                                     if self.custom_allreduce_triggers:
                                         if id(param) in self.allreduce_trigger_params:
                                             ship_tmp_bucket = True
                                     else:
                                         self.tmp_numels[current_type] += param.numel()
                                         if self.tmp_numels[current_type] >= self.message_size:
                                             ship_tmp_bucket = True
 
                                     # To consider: If custom_allreduce_triggers are in use, ship all
                                     # tmp_buckets, not just tmp_buckets[current_type].
                                     if ship_tmp_bucket:
                                         self.active_i_buckets.append(self.tmp_buckets[current_type])
                                         self.tmp_buckets[current_type] = []
                                         self.tmp_numels[current_type] = 0
 
                                 if not self.callback_queued:
                                     Variable._execution_engine.queue_callback(allreduce_params)
                                     self.callback_queued = True
                             else:
                                 if not self.callback_queued:
                                     Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
                                     self.callback_queued = True
 
                                 self.comm_ready_buckets(param)
 
                     grad_acc.register_hook(allreduce_hook)
                     self.grad_accs.append(grad_acc)
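
As background on the hook registration above (a standard PyTorch pattern, not something introduced by this commit): `param.expand_as(param).grad_fn.next_functions[0][0]` is the parameter's AccumulateGrad node, and a hook registered on it fires once that parameter's gradient has been written during backward, which is where apex kicks off its bucketed allreduces. A minimal standalone sketch:

```python
import torch

p = torch.nn.Parameter(torch.randn(3))

# expand_as creates a non-leaf tensor whose grad_fn's next_functions[0][0]
# is the AccumulateGrad node responsible for writing p.grad in backward.
grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]

def hook(*unused):
    # Called after p.grad has been accumulated for this backward pass.
    print("gradient ready:", p.grad)

grad_acc.register_hook(hook)

(p * 2.0).sum().backward()   # the hook prints once p.grad is populated
```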
@@ -422,7 +431,7 @@ class DistributedDataParallel(Module):
         # training script, and overwritten in the next forward pass. So it's harmless.
         if self.retain_allreduce_buffers:
             self.allreduce_buffers = [None for _ in range(len(split_buckets))]
 
         for i, bucket in enumerate(split_buckets):
             allreduced = self.allreduce_maybe_retain(bucket, i)
@@ -469,38 +478,39 @@ class DistributedDataParallel(Module):
     def forward(self, *inputs, **kwargs):
         result = self.module(*inputs, **kwargs)
 
+        if not self.disable_allreduce:
             if not self.delay_allreduce:
                 param_list = [param for param in self.module.parameters() if param.requires_grad]
 
                 # Conditions under which to refresh self.record
                 # Forward has the authority to set needs_refresh to True, but only allreduce_params
                 # in backward has the authority to set needs_refresh to False.
                 # Parentheses are not necessary for correct order of operations, but make the intent clearer.
                 if ((not self.active_params) or
                     (len(param_list) != len(self.active_params)) or
                     any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
                     self.needs_refresh = True
 
                 if self.needs_refresh:
                     self.active_i_buckets = []
                     self.buckets = []
                     self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
                     self.tmp_numels = [0, 0, 0]
                     self.bucket_sizes = []
                     self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
                     self.param_id_to_bucket = {}
                 else:
                     self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                                     for i in range(self.num_buckets)]
                     self.buckets_ready_size = [0 for i in range(self.num_buckets)]
                     if(self.retain_allreduce_buffers):
                         self.allreduce_buffers = [None for _ in range(self.num_buckets)]
                     self.next_bucket = 0
                     self.ready_buckets_not_reduced = set()
 
                 self.active_params = param_list
 
             self.callback_queued = False
 
         return result