Unverified Commit 175fdeb0 authored by Benjamin Lefaudeux, committed by GitHub

[feature] Unit test with and without buckets for all ShardedDDP unit tests (#400)

* Test with and without buckets for all the ShardedDDP unit tests
* Parametrize all the test cases
* Refactoring, adding even more combinations in some cases
* Handle hosts without CUDA
parent 4396ef4a
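A minimal sketch of what this kind of parametrization can look like (test name and values here are illustrative, not copied from the PR):

```python
import pytest
import torch


# Illustrative sketch only: the real tests spawn process groups and wrap models
# in ShardedDDP; this just shows the shape of the added parametrization.
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])  # without / with buckets
def test_sharded_ddp_case(reduce_buffer_size: int) -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"  # handle hosts without CUDA
    assert reduce_buffer_size in (0, 2 ** 20)
    assert device in ("cpu", "cuda")
```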
@@ -51,9 +51,9 @@ class ShardedDataParallel(nn.Module):
Synchronize the models between the ranks when starting up. Not needed if each rank has the same seed,
or the training restarts from a saved state
reduce_buffer_size (int):
The max size of the buffer used to batch the small parameter tensors, in number of elements (default 8M).
The max size of the buffer used to batch the small parameter tensors, in number of elements (default: 0, i.e. bucketing disabled).
This will impact the long-term memory consumption, because these buckets correspond to parameters which will not be sharded.
Set to 0 to remove all bucketing.
Set to 0 to remove all bucketing; values between 1M and 8M are usually reasonable.
auto_refresh_trainable (bool):
(default: True) Check whether the parameters' trainability (`requires_grad`) has changed and update both ShardedDDP
and OSS automatically if this is the case. If set to False, `refresh_trainable()` needs to be called anytime
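For reference, a minimal usage sketch of the parameter documented above, assuming `torch.distributed` has already been initialized and following the standard OSS + ShardedDDP wrapping pattern:

```python
import torch
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

# Sketch only: assumes dist.init_process_group(...) has been called beforehand.
model = torch.nn.Linear(32, 32)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)

# reduce_buffer_size=0 (the new default) disables bucketing entirely; a value in
# the 2 ** 20 to 2 ** 23 range trades a bit of memory for fewer, larger reduce calls.
sharded_model = ShardedDDP(model, optimizer, reduce_buffer_size=2 ** 22)
```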
@@ -98,7 +98,7 @@ class ShardedDataParallel(nn.Module):
process_group: Any = None,
broadcast_buffers: bool = True,
sync_models_at_startup: bool = True,
reduce_buffer_size: int = 2 ** 23,
reduce_buffer_size: int = 0,
auto_refresh_trainable: bool = True,
):
super().__init__()
@@ -111,6 +111,7 @@ class ShardedDataParallel(nn.Module):
# Handle a no_sync() context which prevents the gradient synchronization
# and accumulates the gradients in place
self.should_accumulate_grads = False
self.accumulate_grads_flipped = False
# Communication related attributes
self.process_group = process_group if process_group is not None else dist.group.WORLD
@@ -153,10 +154,6 @@ class ShardedDataParallel(nn.Module):
# - setup buckets and tensor views
model_size = sum([p.numel() for p in self.module.parameters()])
if dist.get_world_size(self.process_group) <= 8:
logging.info("Assuming single node environment. De-activating ShardedDDP buckets")
reduce_buffer_size = 0
self.buffer_max_size = min(reduce_buffer_size, model_size)
logging.info(
"ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
@@ -230,6 +227,11 @@ class ShardedDataParallel(nn.Module):
.. note::
This method modifies the module in-place.
.. warning::
Device changes are not supported and will raise an exception. The issue in that case is not
really ShardedDDP, but OSS, which would not be aware of the device change and whose buffers would
end up in a broken state.
Arguments:
device (:class:`torch.device`): the desired device of the parameters and buffers in this module.
dtype (:class:`torch.dtype`): the desired floating point type of the floating point parameters and buffers.
@@ -237,14 +239,18 @@ class ShardedDataParallel(nn.Module):
Returns:
Module: self.
"""
for device in self.buckets.keys():
for bucket in self.buckets[device]:
assert device in self.buckets.keys(), "Changing devices is not supported, because this would break OSS's state"
assert (
len(self.buckets.keys()) == 1
), "Several devices specified to begin with, incompatible with setting a single device here"
for _device in self.buckets.keys():
for bucket in self.buckets[_device]:
bucket.buffer = bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.module.to(device)
self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
def refresh_trainable(self) -> None:
""" If the module trainability has changed, update all the assumptions """
@@ -320,7 +326,7 @@ class ShardedDataParallel(nn.Module):
See :meth:`torch.optim.Optimizer.zero_grad` for details.
"""
for index, trainable_param in enumerate(self._trainable_params):
for index, trainable_param in enumerate(self._all_params):
if set_to_none and not self._should_bucket_grad[index]:
trainable_param.grad = None
elif trainable_param.grad is not None:
@@ -339,6 +345,7 @@ class ShardedDataParallel(nn.Module):
old_should_accumulate_grads = self.should_accumulate_grads
self.should_accumulate_grads = True
yield
self.accumulate_grads_flipped = self.should_accumulate_grads != old_should_accumulate_grads
self.should_accumulate_grads = old_should_accumulate_grads
@torch.no_grad()
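The `accumulate_grads_flipped` flag records that a `no_sync()` window just ended, which is what the relaxed bucket assertion further down checks for. A hedged sketch of the gradient-accumulation pattern this supports (`sharded_model`, `optimizer`, and the micro-batches are assumed to exist):

```python
# Sketch only: skip the cross-rank reduce for the first micro-batch, then let the
# last backward trigger the synchronization before the optimizer step.
with sharded_model.no_sync():
    sharded_model(micro_batch_a).sum().backward()  # gradients accumulate locally

sharded_model(micro_batch_b).sum().backward()      # this backward reduces across ranks
optimizer.step()
optimizer.zero_grad()
```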
@@ -352,13 +359,19 @@ class ShardedDataParallel(nn.Module):
assert self._bucket_list is not None
for bucket in self._bucket_list:
assert not self.training or self.should_accumulate_grads or bucket.sent, (
"A bucket failed to be sent, probably unused parameters."
+ "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
assert (
self.accumulate_grads_flipped or not self.training or self.should_accumulate_grads or bucket.sent
), (
"A bucket failed to be sent, probably unused parameters. "
+ "Either mark the unused parameter as not trainable (`.requires_grad = False`) "
+ "or de-activate ShardedDDP buckets -set `reduce_buffer_size` to 0-"
)
bucket.reset()
if not self.should_accumulate_grads:
self.accumulate_grads_flipped = False
def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
""" Look up where this parameter belongs to """
for optim in self.sharded_optimizers:
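To make the suggested remedy in the assertion message above concrete, a small hedged sketch (`unused_head` is a hypothetical sub-module that never contributes to the loss):

```python
# Sketch only: freeze the branch that never receives gradients so that ShardedDDP
# stops expecting its bucket to be reduced.
for p in sharded_model.module.unused_head.parameters():
    p.requires_grad = False

# With auto_refresh_trainable=True this is picked up at the next step; otherwise
# the partitioning has to be refreshed explicitly.
sharded_model.refresh_trainable()
```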
@@ -394,10 +407,12 @@ class ShardedDataParallel(nn.Module):
param.grad = None
# Async reduce for this buffer, keep track of the work handle
dst_global_rank = OSS.get_global_rank(self.process_group, dst_rank)
self._work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=param.grad.data, dst=dst_rank, group=self.process_group, async_op=True
tensor=param.grad.data, dst=dst_global_rank, group=self.process_group, async_op=True
),
callback=cleanup,
)
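The change above matters because `torch.distributed.reduce` addresses `dst` as a global rank even when a sub-process-group is passed, hence the translation through `OSS.get_global_rank`. A hedged sketch, assuming an initialized default group and a sub-group `pg`:

```python
import torch
import torch.distributed as dist
from fairscale.optim import OSS


def reduce_to_group_rank(tensor: torch.Tensor, pg, group_rank: int):
    """Sketch only: translate a rank local to `pg` into a global rank before reducing."""
    global_rank = OSS.get_global_rank(pg, group_rank)
    return dist.reduce(tensor=tensor, dst=global_rank, group=pg, async_op=True)
```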
@@ -435,7 +450,10 @@ class ShardedDataParallel(nn.Module):
self._work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
tensor=bucket.buffer,
dst=bucket.destination,
group=self.process_group,
async_op=True,
),
callback=None,
)
@@ -470,33 +488,11 @@ class ShardedDataParallel(nn.Module):
p_tmp = param.expand_as(param)
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0]
dst_rank = OSS.get_global_rank(self.process_group, self._trainable_param_to_rank[param])
dst_rank = self._trainable_param_to_rank[param]
grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank))
self._grad_accs.append(grad_acc) # keep this function in scope
# Add a hook on the module to flush the buckets, if needed
if self.use_buckets:
def bucket_flush(*_: Any) -> None:
assert self._bucket_list is not None
handle = None
for bucket in self._bucket_list:
if not bucket.sent:
# Reduce the bucket. Some parameters went unused and this bucket was not flushed
bucket.buffer.mul_(self.world_size_scaling)
bucket.sent = True
handle = dist.reduce(
tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
)
# Only wait on the last handle
if handle:
handle.wait()
self.module.register_backward_hook(bucket_flush)
@torch.no_grad()
def _sync_params_and_buffers(self) -> None:
"""
@@ -545,7 +541,7 @@ class ShardedDataParallel(nn.Module):
for param in self._trainable_params:
device = param.device
dst_rank = OSS.get_global_rank(self.process_group, self._trainable_param_to_rank[param])
dst_rank = self._trainable_param_to_rank[param]
if param.device not in self.buckets.keys():
self.buckets[param.device] = [
@@ -554,7 +550,7 @@ class ShardedDataParallel(nn.Module):
]
bucket = self.buckets[device][dst_rank]
bucket.destination = dst_rank
bucket.destination = OSS.get_global_rank(self.process_group, dst_rank)
# Criteria to decide whether this parameter is to be bucketed or not:
# - enough room in the bucket
@@ -412,7 +412,7 @@ class OSS(Optimizer):
def refresh_trainable(self) -> None:
""" Updates the partitioning and communication patterns if the trainability (`requires_grad`)
of some parameters changed
of some parameters changed.
"""
# Create the optim which will work on the param shard
@@ -54,7 +54,7 @@ skip_if_single_gpu = pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="multiple GPUs required"
)
skip_if_less_four_gpu = pytest.mark.skipif(
skip_if_less_than_four_gpu = pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason="4 GPUs or more required"
)
@@ -67,6 +67,11 @@ skip_if_py39_no_cuda = pytest.mark.skipif(
reason="Python3.9 wo CUDA is skipped",
)
available_devices = ["cpu"]
if torch.cuda.is_available():
available_devices.append("cuda")
_, filename_mpi = tempfile.mkstemp()
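A hedged sketch of how these helpers are typically consumed in a test file (test names and bodies are illustrative, and the helpers are assumed to be imported from this testing module):

```python
import pytest
import torch


@pytest.mark.parametrize("device", available_devices)  # "cpu" everywhere, "cuda" if present
def test_runs_on_every_available_device(device: str) -> None:
    assert torch.ones(8, device=device).sum().item() == 8


@skip_if_less_than_four_gpu
def test_needs_four_gpus() -> None:
    assert torch.cuda.device_count() >= 4
```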
@@ -418,3 +423,31 @@ def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module,
for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
assert torch.allclose(b_a, b_b), f"Model buffers differ {b_a} - {b_b}\n" + message
def check_same_models_across_ranks(
model: torch.nn.Module, process_group: Any, params_should_be_equal: bool, check_broadcast_buffers: bool
) -> None:
world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group)
for param in model.parameters():
# Collect the parameters across all the ranks
receptacle = [param.clone() for _ in range(world_size)]
dist.all_gather(receptacle, param, group=process_group)
if rank == 0:
for sync_p in receptacle[1:]:
assert not params_should_be_equal or torch.all(
torch.eq(receptacle[0], sync_p)
), "Model parameters differ between ranks"
# Check that all the buffers are in sync (rank 0 is the authoritative reference)
if check_broadcast_buffers:
for buffer in model.buffers():
receptacle = [buffer.clone() for _ in range(world_size)]
dist.all_gather(receptacle, buffer, group=process_group)
if rank == 0:
for sync_b in receptacle[1:]:
assert not params_should_be_equal or torch.all(
torch.eq(receptacle[0], sync_b)
), "Model buffers differ between ranks"