Commit 8521bb22 authored by Michael Carilli

Patching in changes to enable multiple allreduces in flight

parent 61b8a0fd
README.md

# Introduction

This repository holds NVIDIA-maintained utilities to streamline
mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intention of Apex is to make up-to-date utilities available to
users as quickly as possible.

## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
@@ -29,7 +29,7 @@ different flags to `amp.initialize`.

## 2. Distributed Training

`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.
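A minimal usage sketch of the wrapper described above, assuming the script is started with `torch.distributed.launch` (the model below is a placeholder):

```python
import argparse
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # supplied by torch.distributed.launch
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model
model = DDP(model)  # parameters broadcast from rank 0; grads allreduced during backward()
```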
apex/amp/_initialize.py

@@ -133,7 +133,7 @@ def _initialize(models, optimizers, properties, num_losses=1):
    if not _amp_state.allow_incoming_model_not_fp32:
        check_params_fp32(models)

    check_optimizers(optimizers)

    # In the future, when FP16_Optimizer can be deprecated and master weights can
@@ -163,7 +163,7 @@ def _initialize(models, optimizers, properties, num_losses=1):
        model.forward = patch_forward(model.forward)

    # State dict trick to recast any preexisting per-param state tensors
    for optimizer in optimizers:
        optimizer.load_state_dict(optimizer.state_dict())
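The round trip in the hunk above works because `torch.optim.Optimizer.load_state_dict` casts per-parameter state tensors to match the (possibly recast) parameters. A hedged, standalone illustration with a placeholder optimizer:

```python
import torch

param = torch.nn.Parameter(torch.zeros(10))
opt = torch.optim.SGD([param], lr=0.1, momentum=0.9)
param.grad = torch.ones_like(param)
opt.step()  # creates a momentum buffer keyed to the current parameter

# Reloading the optimizer's own state dict recasts any preexisting per-param
# state tensors to match the current parameters, which is the "state dict trick"
# the comment in the hunk above refers to.
opt.load_state_dict(opt.state_dict())
```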
apex/parallel/distributed.py

@@ -44,7 +44,7 @@ def apply_flat_dist_call(bucket, call, extra_args=None):
    if call is dist.all_reduce:
        coalesced /= dist.get_world_size()

    for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
        buf.copy_(synced)

@@ -54,7 +54,7 @@ def split_half_float_double(tensors):
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets

def split_by_type(tensors):
@@ -69,12 +69,12 @@ def split_by_type(tensors):

# flat_dist_call organizes 'tensors' by type.
def flat_dist_call(tensors, call, extra_args=None):
    buckets = split_by_type(tensors)
    for tp in buckets:
        bucket = buckets[tp]
        apply_flat_dist_call(bucket, call, extra_args)
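A hedged illustration of the helper above, mirroring the pattern used when `DistributedDataParallel` broadcasts parameters at construction time (the `model` here is a placeholder, and the process group is assumed to be initialized already):

```python
import torch
import torch.distributed as dist

# Placeholder module; broadcast every parameter from rank 0 in one flattened
# call per tensor type, so all ranks start from identical values.
model = torch.nn.Linear(8, 8).cuda()
params = [param.data for param in model.parameters()]
flat_dist_call(params, dist.broadcast, (0,))
```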
def extract_tensors(maybe_tensor, tensor_list):
    if torch.is_tensor(maybe_tensor):
        tensor_list.append(maybe_tensor)

@@ -85,7 +85,7 @@ def extract_tensors(maybe_tensor, tensor_list):
        except TypeError:
            return
class Reducer(object):
    """
    :class:`apex.parallel.Reducer` is a simple class that helps allreduce a module's parameters

@@ -93,13 +93,13 @@ class Reducer(object):
    Unlike :class:`DistributedDataParallel`, :class:`Reducer` will not automatically allreduce
    parameters during ``backward()``.
    Instead, :class:`Reducer` waits for the user to call ``<reducer_instance>.reduce()`` manually.
    This enables, for example, delaying the allreduce to be carried out every
    several iterations instead of every single iteration.

    Like :class:`DistributedDataParallel`, :class:`Reducer` averages any tensors it allreduces
    over the number of participating processes.

    :class:`Reducer` is designed to work with the upstream launch utility script
    ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
    When used with this launcher, :class:`Reducer` assumes 1:1 mapping of processes to GPUs.
    It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.

@@ -109,7 +109,7 @@ class Reducer(object):
    Args:
        module_or_grads_list: Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced. If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they're all initialized with the same values. If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module's parameters at the beginning of training.
    """
    def __init__(self, module_or_grads_list):
        if isinstance(module_or_grads_list, Module):
            self.module = module_or_grads_list

@@ -119,26 +119,26 @@ class Reducer(object):
            self.module = None
            self.grads = []
            extract_tensors(module_or_grads_list, self.grads)

    def reduce(self):
        if self.module:
            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
            flat_dist_call(grads, dist.all_reduce)
        else:
            flat_dist_call(self.grads, dist.all_reduce)
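A hedged sketch of the manual-reduction pattern the docstring describes, with a placeholder model and an illustrative reduce-every-4-iterations schedule (assumes the process group is already initialized, e.g. via `torch.distributed.launch`):

```python
import torch
from apex.parallel import Reducer

model = torch.nn.Linear(64, 64).cuda()   # placeholder model
reducer = Reducer(model)                 # broadcasts initial parameters from rank 0
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(16):
    out = model(torch.randn(8, 64, device="cuda"))
    out.sum().backward()                 # gradients accumulate locally
    if (step + 1) % 4 == 0:
        reducer.reduce()                 # allreduce + average grads across processes
        optimizer.step()
        optimizer.zero_grad()
```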
class DistributedDataParallel(Module):
    """
    :class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables
    easy multiprocess distributed data parallel training, similar to ``torch.nn.parallel.DistributedDataParallel``. Parameters are broadcast across participating processes on initialization, and gradients are
    allreduced and averaged over processes during ``backward()``.

    :class:`DistributedDataParallel` is optimized for use with NCCL. It achieves high performance by
    overlapping communication with computation during ``backward()`` and bucketing smaller gradient
    transfers to reduce the total number of transfers required.

    :class:`DistributedDataParallel` is designed to work with the upstream launch utility script
    ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
    When used with this launcher, :class:`DistributedDataParallel` assumes 1:1 mapping of processes to GPUs.
    It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.

@@ -161,20 +161,22 @@ class DistributedDataParallel(Module):
    """
    def __init__(self,
                 module,
                 message_size=10000000,
                 delay_allreduce=False,
                 shared_param=None,
                 allreduce_trigger_params=None,
                 retain_allreduce_buffers=False,
                 allreduce_always_fp32=False,
+                allreduce_different_streams=False,
                 gradient_average=True,
                 gradient_predivide_factor=1.0,
-                gradient_average_split_factor=None):
+                gradient_average_split_factor=None,
+                prof=False):
        super(DistributedDataParallel, self).__init__()
        # Backward/forward compatibility around
        # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and
        # https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86
        if hasattr(dist, "get_backend"):

@@ -184,13 +186,20 @@ class DistributedDataParallel(Module):
            else:
                self.backend_enum_holder = dist.Backend
        else:
            self._backend = dist._backend
            self.backend_enum_holder = dist.dist_backend

        self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
+        self.prof = prof
+
+        if allreduce_different_streams and delay_allreduce:
+            raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
+
+        self.allreduce_different_streams = allreduce_different_streams
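A hedged sketch of constructing the wrapper with the new flag handled above; the model is a placeholder and the process group is assumed to be initialized already:

```python
import torch
from apex.parallel import DistributedDataParallel

net = torch.nn.Linear(256, 256).cuda()             # placeholder model
net = DistributedDataParallel(net,
                              delay_allreduce=False,             # required by this mode
                              allreduce_different_streams=True,  # one stream/process group per bucket
                              prof=True)                         # emit NVTX ranges for profiling
```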
        if shared_param is not None:
            raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")

        if gradient_average_split_factor is not None:
            print("Warning: gradient_average_split_factor has been renamed to gradient_predivide_factor. For now, gradient_average_split_factor will also work, but please update to gradient_predivide_factor instead.")
@@ -206,25 +215,25 @@ class DistributedDataParallel(Module):
        self.custom_allreduce_triggers = False
        if allreduce_trigger_params is not None:
            if delay_allreduce:
                raise ValueError("Setting allreduce_trigger_params is only valid if delay_allreduce=False.")
            self.custom_allreduce_triggers = True
            self.allreduce_trigger_params = set([id(param) for param in allreduce_trigger_params])

        self.delay_allreduce = delay_allreduce
        self.message_size = message_size

-        self.reduction_stream = torch.cuda.Stream()
-        self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
+        self.bucket_streams = []
+        self.bucket_events = []
        self.module = module

        if self._backend == self.backend_enum_holder.NCCL:
            for param in self.module.parameters():
                assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."

        self.active_params = []

        self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
                                    "torch.cuda.FloatTensor" : 1,
                                    "torch.cuda.DoubleTensor" : 2}
@@ -241,19 +250,25 @@ class DistributedDataParallel(Module):
    def __setstate__(self, state):
        super(DistributedDataParallel, self).__setstate__(state)

-        self.reduction_stream = torch.cuda.Stream()
-        self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
+        if self.allreduce_different_streams and self.delay_allreduce:
+            raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
+
+        if self.delay_allreduce:
+            self.needs_refresh = True
+
+        self.bucket_streams = []
+        self.bucket_events = []
    def __getstate__(self):
        attrs = copy.copy(self.__dict__)
        if self._backend != self.backend_enum_holder.NCCL:
-            del attrs['self.reduction_stream']
-            del attrs['self.reduction_event']
+            del attrs['self.bucket_streams']
+            del attrs['self.bucket_events']
        return attrs
    # Broadcast rank 0's bucket structure across all processes, and have all processes
    # regenerate their bucket structures to match.
    def sync_bucket_structure(self):
        # Append leftover buckets
        for tmp_bucket in self.tmp_buckets:

@@ -263,8 +278,8 @@ class DistributedDataParallel(Module):
        self.num_buckets = len(self.active_i_buckets)
        self.bucket_sizes = [len(bucket) for bucket in self.active_i_buckets]

        info_tensor = torch.cuda.IntTensor([self.num_buckets] +
                                           self.bucket_sizes +
                                           list(chain(*self.active_i_buckets)))

        dist.broadcast(info_tensor, 0)
@@ -272,27 +287,27 @@ class DistributedDataParallel(Module):
        info = [int(entry) for entry in info_tensor]

        self.num_buckets = info[0]
        self.bucket_sizes = info[1:self.num_buckets + 1]

        self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                        for i in range(self.num_buckets)]

        # Technically, active_i_buckets' work is done. But the information is still useful to
        # keep around. Therefore, refresh active_i_buckets based on rank 0 as well.
        self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])]
                                 for i in range(self.num_buckets)]

        flattened_buckets = info[self.num_buckets + 1:]
        flat_i = 0
        for bucket_idx in range(self.num_buckets):
            for bucket_loc in range(self.bucket_sizes[bucket_idx]):
                param_i = flattened_buckets[flat_i]
                self.active_i_buckets[bucket_idx][bucket_loc] = param_i
                self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc)
                flat_i += 1
    def create_hooks(self):
        # Fallback hook that's only called at the end of backward.
        # Used if you deliberately want to delay allreduces to the end, or to refresh the
        # bucket structure that will be used to overlap communication with computation in later
        # iterations.
        def allreduce_params():

@@ -307,9 +322,10 @@ class DistributedDataParallel(Module):
        def overlapping_backward_epilogue():
-            self.reduction_stream.record_event(self.reduction_event)
-            torch.cuda.current_stream().wait_event(self.reduction_event)
+            for stream, event in zip(self.bucket_streams, self.bucket_events):
+                stream.record_event(event)
+                torch.cuda.current_stream().wait_event(event)

            # Sanity checks that all the buckets were kicked off
            if self.next_bucket != self.num_buckets:
                raise RuntimeError("In epilogue, next_bucket ({}) != num_buckets ({}). ".format(
@@ -319,7 +335,7 @@ class DistributedDataParallel(Module):
            for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
                if actual != expected:
                    raise RuntimeError("Some param buckets were not allreduced.")

        self.grad_accs = []
        for param in self.module.parameters():
@@ -329,6 +345,9 @@ class DistributedDataParallel(Module):
                grad_acc = param_tmp.grad_fn.next_functions[0][0]

                def allreduce_hook(*unused):
+                    if self.prof:
+                        torch.cuda.nvtx.range_push("allreduce_hook")
+
                    if self.delay_allreduce or self.needs_refresh:
                        # TODO: How do we want to handle multiple backward passes between
                        # each forward, e.g., backward passes with retain_graph=True?

@@ -339,8 +358,8 @@ class DistributedDataParallel(Module):
                                # Float, half, and double tensors are grouped into buckets separately.
                                current_type = self.param_type_to_tmp_i[param.type()]

                                self.tmp_buckets[current_type].append(active_i)

                                ship_tmp_bucket = False
                                if self.custom_allreduce_triggers:
@@ -357,46 +376,68 @@ class DistributedDataParallel(Module):
                                    self.active_i_buckets.append(self.tmp_buckets[current_type])
                                    self.tmp_buckets[current_type] = []
                                    self.tmp_numels[current_type] = 0

                            if not self.callback_queued:
                                Variable._execution_engine.queue_callback(allreduce_params)
                                self.callback_queued = True
                    else:
                        if not self.callback_queued:
                            Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
                            self.callback_queued = True

                        self.comm_ready_buckets(param)

+                    if self.prof:
+                        torch.cuda.nvtx.range_pop()
+
                grad_acc.register_hook(allreduce_hook)
                self.grad_accs.append(grad_acc)

            wrapper(param)
+    def _stream_this_bucket(self, bucket_idx):
+        if self.allreduce_different_streams:
+            return self.bucket_streams[bucket_idx]
+        else:
+            return self.bucket_streams[0]
+
+    def _event_this_bucket(self, bucket_idx):
+        if self.allreduce_different_streams:
+            return self.bucket_events[bucket_idx]
+        else:
+            return self.bucket_events[0]
-    def allreduce_bucket(self, bucket):
+    def allreduce_bucket(self, bucket, bucket_idx=-1):
        tensor = flatten(bucket)

        tensor_to_allreduce = tensor

        if self.allreduce_always_fp32:
            tensor_to_allreduce = tensor.float()

        if self.gradient_predivide_factor != 1.0:
            tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)

-        dist.all_reduce(tensor_to_allreduce)
+        if self.allreduce_different_streams and self.bucket_pgs:
+            dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
+        else:
+            dist.all_reduce(tensor_to_allreduce)

        if self.gradient_average:
            tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)

        if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
            tensor.copy_(tensor_to_allreduce)

        return tensor
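A small worked example, with assumed values, of the two-stage scaling in `allreduce_bucket` above: the pre-division and the post-allreduce multiplication compose to a plain average over ranks when `gradient_average=True`:

```python
world_size = 8                 # assumed number of ranks
gradient_predivide_factor = 4.0

pre_scale  = 1.0 / gradient_predivide_factor          # applied before dist.all_reduce
post_scale = gradient_predivide_factor / world_size   # applied after the allreduce

assert pre_scale * post_scale == 1.0 / world_size     # net effect: gradients averaged over ranks
```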
    def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
-        allreduced = self.allreduce_bucket(bucket)
+        allreduced = self.allreduce_bucket(bucket, bucket_idx)
        if self.retain_allreduce_buffers:
            if self.allreduce_buffers[bucket_idx] is not None:
                raise RuntimeError("The backward pass is attempting to replace an already-filled "
@@ -420,11 +461,11 @@ class DistributedDataParallel(Module):
        split_buckets = split_half_float_double(grads)

        # If retain_allreduce_buffers is True and delay_allreduce is False,
        # this will only be done during the first backward pass, ignored by the
        # training script, and overwritten in the next forward pass. So it's harmless.
        if self.retain_allreduce_buffers:
            self.allreduce_buffers = [None for _ in range(len(split_buckets))]

        for i, bucket in enumerate(split_buckets):
            allreduced = self.allreduce_maybe_retain(bucket, i)
@@ -432,6 +473,8 @@ class DistributedDataParallel(Module):
    def comm_ready_buckets(self, param):
        # Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
        # self.reduction_stream.wait_stream(torch.cuda.current_stream())
+        if self.prof:
+            torch.cuda.nvtx.range_push("comm_ready_buckets")
+
        bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]
@@ -444,9 +487,11 @@ class DistributedDataParallel(Module):
        if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
            if bucket_idx == self.next_bucket:
-                torch.cuda.current_stream().record_event(self.reduction_event)
-                self.reduction_stream.wait_event(self.reduction_event)
-                with torch.cuda.stream(self.reduction_stream):
+                bucket_stream = self._stream_this_bucket(bucket_idx)
+                bucket_event = self._event_this_bucket(bucket_idx)
+
+                torch.cuda.current_stream().record_event(bucket_event)
+                bucket_stream.wait_event(bucket_event)
+                with torch.cuda.stream(bucket_stream):
                    self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)

                self.next_bucket += 1
@@ -462,16 +507,22 @@ class DistributedDataParallel(Module):
                        elif i == self.next_bucket:
                            self.allreduce_maybe_retain(self.buckets[i], i)
                            self.ready_buckets_not_reduced.remove(i)
                            self.next_bucket += 1
                        else:
                            raise ValueError("i should always be >= next_bucket")
            else:
                self.ready_buckets_not_reduced.add(bucket_idx)
+
+        if self.prof:
+            torch.cuda.nvtx.range_pop()
    def forward(self, *inputs, **kwargs):
        result = self.module(*inputs, **kwargs)

+        if self.prof:
+            torch.cuda.nvtx.range_push("forward pass DDP logic")
+
        if not self.delay_allreduce:
            param_list = [param for param in self.module.parameters() if param.requires_grad]
@@ -479,7 +530,7 @@ class DistributedDataParallel(Module):
            # Forward has the authority to set needs_refresh to True, but only allreduce_params
            # in backward has the authority to set needs_refresh to False.
            # Parentheses are not necessary for correct order of operations, but make the intent clearer.
            if ((not self.active_params) or
                (len(param_list) != len(self.active_params)) or
                any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
                self.needs_refresh = True
@@ -490,19 +541,53 @@ class DistributedDataParallel(Module):
                self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
                self.tmp_numels = [0, 0, 0]
                self.bucket_sizes = []
                self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
                self.param_id_to_bucket = {}
+                self.bucket_pgs = []
+                self.bucket_streams = []
+                self.bucket_events = []
            else:
-                self.buckets = [[None for _ in range(self.bucket_sizes[i])]
-                                for i in range(self.num_buckets)]
+                if not self.buckets:
+                    self.buckets = [[None for _ in range(self.bucket_sizes[i])]
+                                    for i in range(self.num_buckets)]
+                else:
+                    assert len(self.buckets) == self.num_buckets, "len(buckets) = {}, expected {}".format(
+                        len(self.buckets), self.num_buckets)
+                    for b, bucket in enumerate(self.buckets):
+                        assert len(bucket) == self.bucket_sizes[b], "len(buckets[{}]) = {}, expected {}".format(
+                            b, len(bucket), self.bucket_sizes[b])
+                        for i in range(len(bucket)):
+                            bucket[i] = None
+                if self.allreduce_different_streams:
+                    if not self.bucket_pgs:
+                        self.bucket_pgs = [dist.new_group() for _ in range(self.num_buckets)]
+                        for i, bg in enumerate(self.bucket_pgs):
+                            print("rank {} created group {} with backend {}".format(
+                                  dist.get_rank(), i, dist.get_backend(bg)))
+
+                if self.allreduce_different_streams:
+                    if not self.bucket_streams:
+                        self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_buckets)]
+                        self.bucket_events = [torch.cuda.Event(enable_timing=False,
+                                                               blocking=False) for _ in range(self.num_buckets)]
+                else:
+                    if not self.bucket_streams:
+                        self.bucket_streams = [torch.cuda.Stream()]
+                        self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]
            self.buckets_ready_size = [0 for i in range(self.num_buckets)]
            if(self.retain_allreduce_buffers):
                self.allreduce_buffers = [None for _ in range(self.num_buckets)]
            self.next_bucket = 0
            self.ready_buckets_not_reduced = set()

            self.active_params = param_list

        self.callback_queued = False

+        if self.prof:
+            torch.cuda.nvtx.range_pop()
+
        return result