Commit 8521bb22 authored by Michael Carilli

Patching in changes to enable multiple allreduces in flight

parent 61b8a0fd
# Introduction
This repository holds NVIDIA-maintained utilities to streamline
mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intention of Apex is to make up-to-date utilities available to
users as quickly as possible.
## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
@@ -29,7 +29,7 @@ different flags to `amp.initialize`.
## 2. Distributed Training
`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.
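A minimal sketch of the intended usage, assuming the standard `torch.distributed.launch` setup (`MyModel` and the `--local_rank` plumbing are placeholders for the user's own code):

```python
import argparse
import torch
from apex.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # supplied by torch.distributed.launch
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = MyModel().cuda()  # placeholder for the user's own module
model = DDP(model)        # params broadcast from rank 0; grads allreduced during backward()
```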
@@ -133,7 +133,7 @@ def _initialize(models, optimizers, properties, num_losses=1):
if not _amp_state.allow_incoming_model_not_fp32:
check_params_fp32(models)
check_optimizers(optimizers)
# In the future, when FP16_Optimizer can be deprecated and master weights can
@@ -163,7 +163,7 @@ def _initialize(models, optimizers, properties, num_losses=1):
model.forward = patch_forward(model.forward)
# State dict trick to recast any preexisting per-param state tensors
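# (Optimizer.load_state_dict casts floating-point state tensors to the dtype and device of the
# params they belong to, so round-tripping through the optimizer's own state_dict aligns any
# preexisting per-param state with the freshly recast parameters.)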
for optimizer in optimizers:
optimizer.load_state_dict(optimizer.state_dict())
@@ -44,7 +44,7 @@ def apply_flat_dist_call(bucket, call, extra_args=None):
if call is dist.all_reduce:
coalesced /= dist.get_world_size()
for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
buf.copy_(synced)
@@ -54,7 +54,7 @@ def split_half_float_double(tensors):
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
def split_by_type(tensors):
@@ -69,12 +69,12 @@ def split_by_type(tensors):
# flat_dist_call organizes 'tensors' by type.
def flat_dist_call(tensors, call, extra_args=None):
buckets = split_by_type(tensors)
for tp in buckets:
bucket = buckets[tp]
apply_flat_dist_call(bucket, call, extra_args)
def extract_tensors(maybe_tensor, tensor_list):
if torch.is_tensor(maybe_tensor):
tensor_list.append(maybe_tensor)
@@ -85,7 +85,7 @@ def extract_tensors(maybe_tensor, tensor_list):
except TypeError:
return
class Reducer(object):
"""
:class:`apex.parallel.Reducer` is a simple class that helps allreduce a module's parameters
@@ -93,13 +93,13 @@ class Reducer(object):
Unlike :class:`DistributedDataParallel`, :class:`Reducer` will not automatically allreduce
parameters during ``backward()``.
Instead, :class:`Reducer` waits for the user to call ``<reducer_instance>.reduce()`` manually.
This enables, for example, delaying the allreduce to be carried out every
several iterations instead of every single iteration.
Like :class:`DistributedDataParallel`, :class:`Reducer` averages any tensors it allreduces
over the number of participating processes.
:class:`Reducer` is designed to work with the upstream launch utility script
``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
When used with this launcher, :class:`Reducer` assumes 1:1 mapping of processes to GPUs.
It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.
@@ -109,7 +109,7 @@ class Reducer(object):
Args:
module_or_grads_list: Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced. If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they're all initialized with the same values. If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module's parameters at the beginning of training.
"""
def __init__(self, module_or_grads_list):
if isinstance(module_or_grads_list, Module):
self.module = module_or_grads_list
@@ -119,26 +119,26 @@ class Reducer(object):
self.module = None
self.grads = []
extract_tensors(module_or_grads_list, self.grads)
def reduce(self):
if self.module:
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
flat_dist_call(grads, dist.all_reduce)
else:
flat_dist_call(self.grads, dist.all_reduce)
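A rough, illustrative sketch of the call pattern described in the docstring above (not part of this diff; `model`, `optimizer`, `loss_fn`, `loader`, and `reduce_every` are placeholders): `Reducer` leaves gradients local until `reduce()` is called, so the allreduce can be amortized over several iterations.

```python
from apex.parallel import Reducer

reducer = Reducer(model)  # also syncs parameters across processes at construction
reduce_every = 4          # arbitrary choice: allreduce once every 4 iterations

for step, (inputs, target) in enumerate(loader):
    loss = loss_fn(model(inputs), target)
    loss.backward()       # unlike DistributedDataParallel, no allreduce happens here
    if (step + 1) % reduce_every == 0:
        reducer.reduce()  # flat allreduce of param.grad, averaged over world size
        optimizer.step()
        optimizer.zero_grad()
```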
class DistributedDataParallel(Module):
"""
:class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables
easy multiprocess distributed data parallel training, similar to ``torch.nn.parallel.DistributedDataParallel``. Parameters are broadcast across participating processes on initialization, and gradients are
allreduced and averaged over processes during ``backward()``.
:class:`DistributedDataParallel` is optimized for use with NCCL. It achieves high performance by
overlapping communication with computation during ``backward()`` and bucketing smaller gradient
transfers to reduce the total number of transfers required.
:class:`DistributedDataParallel` is designed to work with the upstream launch utility script
``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
When used with this launcher, :class:`DistributedDataParallel` assumes 1:1 mapping of processes to GPUs.
It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.
@@ -161,20 +161,22 @@ class DistributedDataParallel(Module):
"""
def __init__(self,
module,
message_size=10000000,
delay_allreduce=False,
shared_param=None,
allreduce_trigger_params=None,
retain_allreduce_buffers=False,
allreduce_always_fp32=False,
allreduce_different_streams=False,
gradient_average=True,
gradient_predivide_factor=1.0,
- gradient_average_split_factor=None):
+ gradient_average_split_factor=None,
+ prof=False):
super(DistributedDataParallel, self).__init__()
# Backward/forward compatibility around
# https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and
# https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86
if hasattr(dist, "get_backend"):
@@ -184,13 +186,20 @@ class DistributedDataParallel(Module):
else:
self.backend_enum_holder = dist.Backend
else:
self._backend = dist._backend
self.backend_enum_holder = dist.dist_backend
self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
self.prof = prof
if allreduce_different_streams and delay_allreduce:
raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
self.allreduce_different_streams = allreduce_different_streams
if shared_param is not None:
raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
if gradient_average_split_factor is not None:
print("Warning: gradient_average_split_factor has been renamed to gradient_predivide_factor. For now, gradient_average_split_factor will also work, but please update to gradient_predivide_factor instead.")
@@ -206,25 +215,25 @@ class DistributedDataParallel(Module):
self.custom_allreduce_triggers = False
if allreduce_trigger_params is not None:
if delay_allreduce:
raise ValueError("Setting allreduce_trigger_params is only valid if delay_allreduce=False.")
raise ValueError("Setting allreduce_trigger_params is only valid if delay_allreduce=False.")
self.custom_allreduce_triggers = True
self.allreduce_trigger_params = set([id(param) for param in allreduce_trigger_params])
self.delay_allreduce = delay_allreduce
self.message_size = message_size
self.reduction_stream = torch.cuda.Stream()
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.bucket_streams = []
self.bucket_events = []
self.module = module
if self._backend == self.backend_enum_holder.NCCL:
for param in self.module.parameters():
assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
self.active_params = []
self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
"torch.cuda.FloatTensor" : 1,
"torch.cuda.DoubleTensor" : 2}
@@ -241,19 +250,25 @@ class DistributedDataParallel(Module):
def __setstate__(self, state):
super(DistributedDataParallel, self).__setstate__(state)
self.reduction_stream = torch.cuda.Stream()
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
if allreduce_different_streams and delay_allreduce:
raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
if self.delay_allreduce:
self.needs_refresh = True
self.bucket_streams = []
self.bucket_events = []
def __getstate__(self):
attrs = copy.copy(self.__dict__)
if self._backend != self.backend_enum_holder.NCCL:
del attrs['self.reduction_stream']
del attrs['self.reduction_event']
del attrs['self.bucket_streams']
del attrs['self.bucket_events']
return attrs
# Broadcast rank 0's bucket structure across all processes, and have all processes
# regenerate their bucket structures to match.
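# (The broadcast payload below is a flat IntTensor laid out as [num_buckets,
# bucket_size_0 ... bucket_size_{N-1}, then every bucket's active-parameter indices
# concatenated in bucket order]; the unpacking code in this function mirrors that layout.)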
def sync_bucket_structure(self):
# Append leftover buckets
for tmp_bucket in self.tmp_buckets:
@@ -263,8 +278,8 @@ class DistributedDataParallel(Module):
self.num_buckets = len(self.active_i_buckets)
self.bucket_sizes = [len(bucket) for bucket in self.active_i_buckets]
info_tensor = torch.cuda.IntTensor([self.num_buckets] +
self.bucket_sizes +
list(chain(*self.active_i_buckets)))
dist.broadcast(info_tensor, 0)
@@ -272,27 +287,27 @@ class DistributedDataParallel(Module):
info = [int(entry) for entry in info_tensor]
self.num_buckets = info[0]
self.bucket_sizes = info[1:self.num_buckets + 1]
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
# Technically, active_i_buckets' work is done. But the information is still useful to
# keep around. Therefore, refresh active_i_buckets based on rank 0 as well.
self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
flattened_buckets = info[self.num_buckets + 1:]
flat_i = 0
for bucket_idx in range(self.num_buckets):
for bucket_loc in range(self.bucket_sizes[bucket_idx]):
param_i = flattened_buckets[flat_i]
self.active_i_buckets[bucket_idx][bucket_loc] = param_i
self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc)
flat_i += 1
def create_hooks(self):
# Fallback hook that's only called at the end of backward.
# Used if you deliberately want to delay allreduces to the end, or to refresh the
# bucket structure that will be used to overlap communication with computation in later
# iterations.
def allreduce_params():
@@ -307,9 +322,10 @@ class DistributedDataParallel(Module):
def overlapping_backward_epilogue():
self.reduction_stream.record_event(self.reduction_event)
torch.cuda.current_stream().wait_event(self.reduction_event)
for stream, event in zip(self.bucket_streams, self.bucket_events):
stream.record_event(event)
torch.cuda.current_stream().wait_event(event)
# Sanity checks that all the buckets were kicked off
if self.next_bucket != self.num_buckets:
raise RuntimeError("In epilogue, next_bucket ({}) != num_buckets ({}). ".format(
@@ -319,7 +335,7 @@ class DistributedDataParallel(Module):
for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
if actual != expected:
raise RuntimeError("Some param buckets were not allreduced.")
self.grad_accs = []
for param in self.module.parameters():
@@ -329,6 +345,9 @@ class DistributedDataParallel(Module):
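# (grad_acc below is the parameter's AccumulateGrad autograd node; a hook registered on it fires
# as soon as that parameter's gradient has been written, which is what allows a full bucket's
# allreduce to be kicked off while the rest of backward is still running.)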
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
if self.prof:
torch.cuda.nvtx.range_push("allreduce_hook")
if self.delay_allreduce or self.needs_refresh:
# TODO: How do we want to handle multiple backward passes between
# each forward, e.g., backward passes with retain_graph=True?
@@ -339,8 +358,8 @@ class DistributedDataParallel(Module):
# Float, half, and double tensors are grouped into buckets separately.
current_type = self.param_type_to_tmp_i[param.type()]
self.tmp_buckets[current_type].append(active_i)
ship_tmp_bucket = False
if self.custom_allreduce_triggers:
@@ -357,46 +376,68 @@ class DistributedDataParallel(Module):
self.active_i_buckets.append(self.tmp_buckets[current_type])
self.tmp_buckets[current_type] = []
self.tmp_numels[current_type] = 0
if not self.callback_queued:
Variable._execution_engine.queue_callback(allreduce_params)
self.callback_queued = True
else:
if not self.callback_queued:
Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
self.callback_queued = True
self.comm_ready_buckets(param)
if self.prof:
torch.cuda.nvtx.range_pop()
grad_acc.register_hook(allreduce_hook)
self.grad_accs.append(grad_acc)
wrapper(param)
def _stream_this_bucket(self, bucket_idx):
if self.allreduce_different_streams:
return self.bucket_streams[bucket_idx]
else:
return self.bucket_streams[0]
def _event_this_bucket(self, bucket_idx):
if self.allreduce_different_streams:
return self.bucket_events[bucket_idx]
else:
return self.bucket_events[0]
def allreduce_bucket(self, bucket):
tensor = flatten(bucket)
tensor_to_allreduce = tensor
if self.allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
- dist.all_reduce(tensor_to_allreduce)
+ if self.allreduce_different_streams and self.bucket_pgs:
+     dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
+ else:
+     dist.all_reduce(tensor_to_allreduce)
if self.gradient_average:
tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
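# (Scaling sanity check, added for clarity: with predivide factor p and world size W, each rank
# contributes grad / p to the allreduce sum, and the sum is then scaled by p / W, so the result
# is (sum over ranks of grad) / W, i.e. an ordinary average. Splitting the division around the
# reduction helps keep FP16 gradients from overflowing during the sum.)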
if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce)
return tensor
def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
- allreduced = self.allreduce_bucket(bucket)
+ allreduced = self.allreduce_bucket(bucket, bucket_idx)
if self.retain_allreduce_buffers:
if self.allreduce_buffers[bucket_idx] is not None:
raise RuntimeError("The backward pass is attempting to replace an already-filled "
@@ -420,11 +461,11 @@ class DistributedDataParallel(Module):
split_buckets = split_half_float_double(grads)
# If retain_allreduce_buffers is True and delay_allreduce is False,
# this will only be done during the first backward pass, ignored by the
# training script, and overwritten in the next forward pass. So it's harmless.
if self.retain_allreduce_buffers:
self.allreduce_buffers = [None for _ in range(len(split_buckets))]
for i, bucket in enumerate(split_buckets):
allreduced = self.allreduce_maybe_retain(bucket, i)
@@ -432,6 +473,8 @@ class DistributedDataParallel(Module):
def comm_ready_buckets(self, param):
# Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
# self.reduction_stream.wait_stream(torch.cuda.current_stream())
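# (Buckets are launched strictly in index order so that every rank issues its allreduces in the
# same sequence. A bucket that becomes ready while earlier buckets are still outstanding is
# parked in ready_buckets_not_reduced and drained later, once next_bucket catches up to it.)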
if self.prof:
torch.cuda.nvtx.range_push("comm_ready_buckets")
bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]
@@ -444,9 +487,11 @@ class DistributedDataParallel(Module):
if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
if bucket_idx == self.next_bucket:
- torch.cuda.current_stream().record_event(self.reduction_event)
- self.reduction_stream.wait_event(self.reduction_event)
- with torch.cuda.stream(self.reduction_stream):
+ bucket_stream = self._stream_this_bucket(bucket_idx)
+ bucket_event = self._event_this_bucket(bucket_idx)
+ torch.cuda.current_stream().record_event(bucket_event)
+ bucket_stream.wait_event(bucket_event)
+ with torch.cuda.stream(bucket_stream):
self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
self.next_bucket += 1
@@ -462,16 +507,22 @@ class DistributedDataParallel(Module):
elif i == self.next_bucket:
self.allreduce_maybe_retain(self.buckets[i], i)
self.ready_buckets_not_reduced.remove(i)
self.next_bucket += 1
else:
raise ValueError("i should always be >= next_bucket")
else:
self.ready_buckets_not_reduced.add(bucket_idx)
if self.prof:
torch.cuda.nvtx.range_pop()
def forward(self, *inputs, **kwargs):
result = self.module(*inputs, **kwargs)
if self.prof:
torch.cuda.nvtx.range_push("forward pass DDP logic")
if not self.delay_allreduce:
param_list = [param for param in self.module.parameters() if param.requires_grad]
@@ -479,7 +530,7 @@ class DistributedDataParallel(Module):
# Forward has the authority to set needs_refresh to True, but only allreduce_params
# in backward has the authority to set needs_refresh to False.
# Parentheses are not necessary for correct order of operations, but make the intent clearer.
if ((not self.active_params) or
(len(param_list) != len(self.active_params)) or
any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
self.needs_refresh = True
@@ -490,19 +541,53 @@ class DistributedDataParallel(Module):
self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
self.tmp_numels = [0, 0, 0]
self.bucket_sizes = []
self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
self.param_id_to_bucket = {}
self.bucket_pgs = []
self.bucket_streams = []
self.bucket_events = []
else:
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
if not self.buckets:
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
else:
assert len(self.buckets) == self.num_buckets, "len(buckets) = {}, expected {}".format(
len(self.buckets), self.num_buckets)
for b, bucket in enumerate(self.buckets):
assert len(bucket) == self.bucket_sizes[b], "len(buckets[{}]) = {}, expected {})".format(
b, len(bucket), self.bucket_sizes[b])
for i in range(len(bucket)):
bucket[i] = None
if self.allreduce_different_streams:
if not self.bucket_pgs:
self.bucket_pgs = [dist.new_group() for _ in range(self.num_buckets)]
for i, bg in enumerate(self.bucket_pgs):
print("rank {} created group {} with backend {}".format(
dist.get_rank(), i, dist.get_backend(bg)))
if self.allreduce_different_streams:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_buckets)]
self.bucket_events = [torch.cuda.Event(enable_timing=False,
blocking=False) for _ in range(self.num_buckets)]
else:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream()]
self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]
self.buckets_ready_size = [0 for i in range(self.num_buckets)]
if(self.retain_allreduce_buffers):
self.allreduce_buffers = [None for _ in range(self.num_buckets)]
self.next_bucket = 0
self.ready_buckets_not_reduced = set()
self.active_params = param_list
self.callback_queued = False
if self.prof:
torch.cuda.nvtx.range_pop()
return result