Unverified commit 175fdeb0, authored by Benjamin Lefaudeux, committed by GitHub

[feature] Unit test with and without buckets for all ShardedDDP unit tests (#400)

* test with and without buckets for all the ShardedDDP unit tests
* parametrize all the things
* refactoring, adding even more combinations at times
* handle hosts not having CUDA
parent 4396ef4a
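The bullets above are terse; as a rough, hypothetical illustration of what "test with and without buckets" means in practice, the sketch below runs one training step through ShardedDDP twice, once with bucketing disabled (`reduce_buffer_size=0`) and once with a 1M-element buffer. The test name, model, and single-process gloo group are made up for this sketch, not taken from the PR's actual tests.

```python
import os

import pytest
import torch
import torch.distributed as dist

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS


@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])  # without / with buckets
def test_one_step(reduce_buffer_size: int) -> None:
    # Single-process gloo group, only to keep the sketch self-contained
    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29501")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

    model = torch.nn.Linear(8, 4)
    optimizer = OSS(model.parameters(), lr=0.1)
    ddp = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

    ddp(torch.randn(2, 8)).sum().backward()
    optimizer.step()
```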
@@ -51,9 +51,9 @@ class ShardedDataParallel(nn.Module):
            Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
            or the training restarts from a saved state
        reduce_buffer_size (int):
-           The max size of the buffer used to batch the small parameter tensors, in number of elements (default 8M).
+           The max size of the buffer used to batch the small parameter tensors, in number of elements (default 0 - unused).
            this will impact the long term memory consumption, because these buckets correspond to parameters which will not be sharded.
-           Set to 0 to remove all bucketing.
+           Set to 0 to remove all bucketing, 1M to 8M is usually reasonable.
        auto_refresh_trainable (bool):
            (default: True) Check whether the parameters trainability (`requires_grad`) has changed and update both ShardedDDP
            and OSS automatically if this is the case. If set to False, `refresh_trainable()` needs to be called anytime
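Since bucketing is now opt-in (the default moved from 8M elements to 0), callers who want the previous behaviour must ask for it explicitly. A minimal sketch, assuming a process group is already initialized:

```python
import torch

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

model = torch.nn.Linear(1024, 1024)
optimizer = OSS(model.parameters(), lr=0.01)

# Default is now reduce_buffer_size=0 (no bucketing). Passing roughly 1M-8M
# elements restores the flat reduce buffers, trading some resident memory
# (bucketed parameters are not sharded) for fewer small reduce calls.
ddp = ShardedDataParallel(model, optimizer, reduce_buffer_size=2 ** 23)
```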
@@ -98,7 +98,7 @@ class ShardedDataParallel(nn.Module):
        process_group: Any = None,
        broadcast_buffers: bool = True,
        sync_models_at_startup: bool = True,
-       reduce_buffer_size: int = 2 ** 23,
+       reduce_buffer_size: int = 0,
        auto_refresh_trainable: bool = True,
    ):
        super().__init__()
@@ -111,6 +111,7 @@ class ShardedDataParallel(nn.Module):
        # Handle a no_sync() context which prevents the gradient synchronization,
        # accumulate in place
        self.should_accumulate_grads = False
+       self.accumulate_grads_flipped = False

        # Communication related attributes
        self.process_group = process_group if process_group is not None else dist.group.WORLD
@@ -153,10 +154,6 @@ class ShardedDataParallel(nn.Module):
        # - setup buckets and tensor views
        model_size = sum([p.numel() for p in self.module.parameters()])
-       if dist.get_world_size(self.process_group) <= 8:
-           logging.info("Assuming single node environment. De-activating ShardedDDP buckets")
-           reduce_buffer_size = 0
-
        self.buffer_max_size = min(reduce_buffer_size, model_size)
        logging.info(
            "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
@@ -230,6 +227,11 @@ class ShardedDataParallel(nn.Module):
        .. note::
            This method modifies the module in-place.

+       .. warning::
+           Device changes are not supported, and this will raise an exception. The issue in that case is not
+           really ShardedDDP, but OSS which will not be aware of the device change, and whose buffers will be
+           in a broken state.
+
        Arguments:
            device (:class:`torch.device`): the desired device of the parameters and buffers in this module.
            dtype (:class:`torch.dtype`): the desired floating point type of the floating point parameters and buffers.
@@ -237,14 +239,18 @@ class ShardedDataParallel(nn.Module):
        Returns:
            Module: self.
        """
-       for device in self.buckets.keys():
-           for bucket in self.buckets[device]:
+       assert device in self.buckets.keys(), "Changing devices is not supported, because this would break OSS's state"
+       assert (
+           len(self.buckets.keys()) == 1
+       ), "Several devices specified to begin with, incompatible with setting a single device here"
+
+       for _device in self.buckets.keys():
+           for bucket in self.buckets[_device]:
                bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)

-       self.module.to(device)
+       self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)

    def refresh_trainable(self) -> None:
        """ If the module trainability has changed, update all the assumptions """
@@ -320,7 +326,7 @@ class ShardedDataParallel(nn.Module):
        See :meth:`torch.optim.Optimizer.zero_grad` for details.
        """
-       for index, trainable_param in enumerate(self._trainable_params):
+       for index, trainable_param in enumerate(self._all_params):
            if set_to_none and not self._should_bucket_grad[index]:
                trainable_param.grad = None
            elif trainable_param.grad is not None:
@@ -339,6 +345,7 @@ class ShardedDataParallel(nn.Module):
        old_should_accumulate_grads = self.should_accumulate_grads
        self.should_accumulate_grads = True
        yield
+       self.accumulate_grads_flipped = self.should_accumulate_grads != old_should_accumulate_grads
        self.should_accumulate_grads = old_should_accumulate_grads

    @torch.no_grad()
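`accumulate_grads_flipped` records that a `no_sync()` window has just closed, which the backward-end check below uses to avoid flagging buckets that were legitimately left unsent during accumulation. Typical gradient-accumulation usage, sketched with a hypothetical `batches` list and the `ddp`/`optimizer` objects from the earlier sketches:

```python
# Accumulate gradients locally over the first micro-batches...
with ddp.no_sync():
    for micro_batch in batches[:-1]:
        ddp(micro_batch).sum().backward()  # no gradient reduction inside the context

# ...then reduce/shard them on the last micro-batch and step the optimizer.
ddp(batches[-1]).sum().backward()
optimizer.step()
optimizer.zero_grad()
```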
@@ -352,13 +359,19 @@ class ShardedDataParallel(nn.Module):
        assert self._bucket_list is not None

        for bucket in self._bucket_list:
-           assert not self.training or self.should_accumulate_grads or bucket.sent, (
-               "A bucket failed to be sent, probably unused parameters."
-               + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
+           assert (
+               self.accumulate_grads_flipped or not self.training or self.should_accumulate_grads or bucket.sent
+           ), (
+               "A bucket failed to be sent, probably unused parameters. "
+               + "Either mark the unused parameter as not trainable (`.requires_grad = False`) "
+               + "or de-activate ShardedDDP buckets -set `reduce_buffer_size` to 0-"
            )
            bucket.reset()

+       if not self.should_accumulate_grads:
+           self.accumulate_grads_flipped = False
+
    def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
        """ Look up where this parameter belongs to """
        for optim in self.sharded_optimizers:
@@ -394,10 +407,12 @@ class ShardedDataParallel(nn.Module):
                    param.grad = None

                # Async reduce for this buffer, log the future
+               dst_global_rank = OSS.get_global_rank(self.process_group, dst_rank)
+
                self._work_handles.append(
                    Workhandle(
                        handle=dist.reduce(
-                           tensor=param.grad.data, dst=dst_rank, group=self.process_group, async_op=True
+                           tensor=param.grad.data, dst=dst_global_rank, group=self.process_group, async_op=True
                        ),
                        callback=cleanup,
                    )
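`dst_rank` is now a rank local to the sharding process group and is translated to a global rank only at the `dist.reduce()` call site, since `torch.distributed` addresses destinations by global rank. A hedged sketch of the distinction (assumes a running job with at least four ranks; the subgroup is made up for illustration):

```python
import torch.distributed as dist

from fairscale.optim import OSS

# Suppose the sharding group only spans global ranks [2, 3]. Inside that group
# a shard owner is identified by its local rank (0 or 1), but dist.reduce()
# expects the destination as a global rank, hence OSS.get_global_rank() above.
subgroup = dist.new_group(ranks=[2, 3])        # must be called on every rank
global_dst = OSS.get_global_rank(subgroup, 1)  # local rank 1 -> global rank 3
```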
@@ -435,7 +450,10 @@ class ShardedDataParallel(nn.Module):
                self._work_handles.append(
                    Workhandle(
                        handle=dist.reduce(
-                           tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
+                           tensor=bucket.buffer,
+                           dst=bucket.destination,
+                           group=self.process_group,
+                           async_op=True,
                        ),
                        callback=None,
                    )
@@ -470,33 +488,11 @@ class ShardedDataParallel(nn.Module):
            p_tmp = param.expand_as(param)
            assert p_tmp.grad_fn is not None
            grad_acc = p_tmp.grad_fn.next_functions[0][0]
-           dst_rank = OSS.get_global_rank(self.process_group, self._trainable_param_to_rank[param])
+           dst_rank = self._trainable_param_to_rank[param]
            grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank))
            self._grad_accs.append(grad_acc)  # keep this function in scope

-       # Add a hook on the module to flush the buckets, if needed
-       if self.use_buckets:
-
-           def bucket_flush(*_: Any) -> None:
-               assert self._bucket_list is not None
-
-               handle = None
-               for bucket in self._bucket_list:
-                   if not bucket.sent:
-                       # Reduce the bucket. Some parameters went unused and this bucket was not flushed
-                       bucket.buffer.mul_(self.world_size_scaling)
-                       bucket.sent = True
-                       handle = dist.reduce(
-                           tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
-                       )
-
-               # Only wait on the last handle
-               if handle:
-                   handle.wait()
-
-           self.module.register_backward_hook(bucket_flush)
-
    @torch.no_grad()
    def _sync_params_and_buffers(self) -> None:
        """
@@ -545,7 +541,7 @@ class ShardedDataParallel(nn.Module):
        for param in self._trainable_params:
            device = param.device
-           dst_rank = OSS.get_global_rank(self.process_group, self._trainable_param_to_rank[param])
+           dst_rank = self._trainable_param_to_rank[param]

            if param.device not in self.buckets.keys():
                self.buckets[param.device] = [
@@ -554,7 +550,7 @@ class ShardedDataParallel(nn.Module):
                ]

            bucket = self.buckets[device][dst_rank]
-           bucket.destination = dst_rank
+           bucket.destination = OSS.get_global_rank(self.process_group, dst_rank)

            # Criteria to decide whether this parameter is to be bucketed or not:
            # - enough room in the bucket
...
@@ -412,7 +412,7 @@ class OSS(Optimizer):
    def refresh_trainable(self) -> None:
        """ Updates the partitioning and communication patterns if the trainability (`requires_grad`)
-       of some parameters changed
+       of some parameters changed.
        """
        # Create the optim which will work on the param shard
...
@@ -54,7 +54,7 @@ skip_if_single_gpu = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="multiple GPUs required"
)

-skip_if_less_four_gpu = pytest.mark.skipif(
+skip_if_less_than_four_gpu = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason="4 GPUs or more required"
)
@@ -67,6 +67,11 @@ skip_if_py39_no_cuda = pytest.mark.skipif(
    reason="Python3.9 wo CUDA is skipped",
)

+available_devices = ["cpu"]
+if torch.cuda.is_available():
+    available_devices.append("cuda")
+
_, filename_mpi = tempfile.mkstemp()
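`available_devices` is what lets the parametrized tests run on CPU-only hosts instead of being skipped outright ("handle hosts not having CUDA" in the commit message). A hypothetical test using it:

```python
import pytest
import torch

available_devices = ["cpu"]
if torch.cuda.is_available():
    available_devices.append("cuda")


@pytest.mark.parametrize("device", available_devices)
def test_forward_shape(device: str) -> None:
    # Runs on "cpu" everywhere, and additionally on "cuda" when a GPU is present
    model = torch.nn.Linear(4, 2).to(device)
    out = model(torch.randn(3, 4, device=device))
    assert out.shape == (3, 2)
```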
@@ -418,3 +423,31 @@ def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module,
        for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
            assert torch.allclose(b_a, b_b), f"Model buffers differ {b_a} - {b_b}\n" + message
+
+
+def check_same_models_across_ranks(
+    model: torch.nn.Module, process_group: Any, params_should_be_equal: bool, check_broadcast_buffers: bool
+) -> None:
+    world_size = dist.get_world_size(process_group)
+    rank = dist.get_rank(process_group)
+
+    for param in model.parameters():
+        # Collect this parameter from all the ranks
+        receptacle = [param.clone() for _ in range(world_size)]
+        dist.all_gather(receptacle, param, group=process_group)
+
+        if rank == 0:
+            for sync_p in receptacle[1:]:
+                assert not params_should_be_equal or torch.all(
+                    torch.eq(receptacle[0], sync_p)
+                ), "Models differ in between ranks"
+
+    # Check that all the buffers are in sync (the authoritative copy is rank 0's)
+    if check_broadcast_buffers:
+        for buffer in model.buffers():
+            receptacle = [buffer.clone() for _ in range(world_size)]
+            dist.all_gather(receptacle, buffer, group=process_group)
+
+            if rank == 0:
+                for sync_b in receptacle[1:]:
+                    assert not params_should_be_equal or torch.all(
+                        torch.eq(receptacle[0], sync_b)
+                    ), "Models differ in between ranks"
...