Unverified Commit fb4eca19 authored by tmarkstrum, committed by GitHub

[Fix][FSDP] fixed padding size of input tensor for reduce scatter (#907)



* fixed padding size of the input tensor for reduce scatter, and fixed an error that assigned the wrong group

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* added changelog

* fixed some comments.

* added a unit test to ensure the reduce_scatter process group size is correct in default cases, and fall back to the default process group when the reduce_scatter process group has the wrong size

* throw an error instead of rolling back to use default process group for reduce_scatter_process_group

* Revert "throw an error instead of rolling back to use default process group for reduce_scatter_process_group"

This reverts commit eab5620da3b726ea55d3088ae4ca10d94dcdf4d9.

* added check for None to avoid unit test failure

* fixed an error to avoid unit test failures
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
parent 0044372c
@@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- FSDP: Enabled reduce_scatter operation to overlap with all_gather stream and computation stream in backward propagation.
This change increases FSDP's throughput. To roll back this change, please pass ProcessGroupName.default to the process_group_reduce_scatter API parameter in FSDP [#897]
### Fixed
- FSDP: Fixed an issue where the padding size of the reduce scatter input tensor did not match the reduce scatter process group size [#907]
## [0.4.4] - 2021-12-21
### Fixed
......
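As a hedged illustration of the rollback path described in the changelog entry above, here is a minimal usage sketch. The `nn.Linear` module is a placeholder, the import paths assume the fairscale version at this commit, and `torch.distributed` must already be initialized:

```python
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.parallel import ProcessGroupName

# Placeholder module; any nn.Module works here.
module = nn.Linear(8, 8)

# Default behavior: a dedicated reduce_scatter process group, so the
# reduce_scatter op overlaps with the all_gather and computation streams.
fsdp_overlap = FullyShardedDataParallel(module)

# Rollback described in the changelog: reuse the default process group
# for reduce_scatter, which disables the overlap.
fsdp_no_overlap = FullyShardedDataParallel(
    module, process_group_reduce_scatter=ProcessGroupName.default
)
```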
@@ -198,6 +198,8 @@ class FullyShardedDataParallel(nn.Module):
If it is a specific ProcessGroup, the reduce_scatter operates on this ProcessGroup, and the overlap still happens.
To disable the overlap feature, set the process group to ProcessGroupName.default. In this case, the reduce_scatter
operation uses the same process group as the default group.
If the reduce scatter process group size differs from the default process group size, the reduce_scatter
operation falls back to the default process group, as illustrated in the sketch after this docstring excerpt.
reshard_after_forward (bool, Optional):
if ``True``, reshard parameters after the forward pass. This saves
memory but slows training. This is only relevant when resharding
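To make the process_group_reduce_scatter contract above concrete, a small sketch of passing a specific ProcessGroup. The setup is hypothetical: it assumes torch.distributed is initialized and builds a group spanning all ranks, since the size check added in this commit requires the group size to match the world size:

```python
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel

# A specific ProcessGroup: reduce_scatter runs on it and overlap is kept.
# Per this commit, its size must match the default group's size, or FSDP
# falls back to the default group with a warning.
rs_group = dist.new_group(ranks=list(range(dist.get_world_size())))
fsdp = FullyShardedDataParallel(
    nn.Linear(8, 8), process_group_reduce_scatter=rs_group
)
```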
@@ -298,7 +300,8 @@ class FullyShardedDataParallel(nn.Module):
self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
process_group_reduce_scatter: Union[ProcessGroup, ProcessGroupName] = ProcessGroupName.reduce_scatter,
# The type of process_group_reduce_scatter can only be either ProcessGroup or ProcessGroupName
process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter,
reshard_after_forward: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
@@ -325,13 +328,35 @@ class FullyShardedDataParallel(nn.Module):
# the rest of operations. The overlap feature in the backward propagation is disabled.
if process_group_reduce_scatter == ProcessGroupName.default:
self.process_group_reduce_scatter = self.process_group
# If ProcessGroupName.reduce_scatter is passed in, the reduce_scatter uses a separate process group
# so that the overlap feature in the backward propagation is enabled.
elif process_group_reduce_scatter == ProcessGroupName.reduce_scatter:
self.process_group_reduce_scatter = get_process_group_cached(ProcessGroupName.reduce_scatter)
else:
self.process_group_reduce_scatter = process_group
# If a specific process group is passed in, the reduce_scatter will use the passed in process group.
if isinstance(process_group_reduce_scatter, ProcessGroup):
self.process_group_reduce_scatter = process_group_reduce_scatter
else:
if not hasattr(process_group_reduce_scatter, "allgather") and hasattr(
process_group_reduce_scatter, "rank"
):
# Likely a dummy process group for unit testing
self.process_group_reduce_scatter = process_group_reduce_scatter
else:
raise TypeError("unsupported type for reduce_scatter process group")
self.rank = self.process_group.rank()
self.world_size = self.process_group.size()
# In a unit test dummy environment, the process_group_reduce_scatter can be None.
if self.process_group_reduce_scatter is not None:
reduce_scatter_group_size = self.process_group_reduce_scatter.size()
# Roll back to the default process group for the reduce scatter operation when the world size and the reduce scatter process group size are different.
if self.world_size != reduce_scatter_group_size:
self.process_group_reduce_scatter = self.process_group
logging.warning(
    "Rolled back to use the default process group for the reduce scatter operation because the "
    f"reduce_scatter process group size is {reduce_scatter_group_size}, which is different from the "
    f"world size {self.world_size}. Please make sure the process_group "
    "parameter uses all the available ranks for optimal performance."
)
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
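For intuition on why the group-size check above matters, a toy sketch of the underlying collective's contract. The calls are real torch.distributed APIs, but the setup is hypothetical and assumes an initialized NCCL process group with CUDA devices:

```python
import torch
import torch.distributed as dist

# reduce_scatter consumes exactly group-size-many input chunks and leaves
# one reduced chunk on each rank, so the number of (padded) gradient
# chunks must match the size of the group the collective runs on.
world = dist.get_world_size()
output = torch.zeros(4, device="cuda")
input_list = [torch.ones(4, device="cuda") for _ in range(world)]
dist.reduce_scatter(output, input_list)  # fails if len(input_list) != world
```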
@@ -1617,7 +1642,7 @@ class FullyShardedDataParallel(nn.Module):
# unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
param.grad = None
callback_fn = functools.partial(self._post_reduction_hook, param)
grad_chunks = chunk_and_pad(grad, self.world_size)
grad_chunks = chunk_and_pad(grad, self.process_group_reduce_scatter.size())
self._reducer.reduce_scatter_async(
grad_chunks, group=self.process_group_reduce_scatter, callback_fn=callback_fn
)
......
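The one-line fix above matters because the chunk count determines the padding. Below is a minimal sketch of what a chunk_and_pad-style helper does; this is a hypothetical reimplementation for illustration only, not fairscale's actual chunk_and_pad:

```python
import torch
import torch.nn.functional as F

def chunk_and_pad_sketch(tensor: torch.Tensor, num_chunks: int) -> list:
    """Flatten, right-pad to a multiple of num_chunks, and split evenly."""
    flat = tensor.flatten()
    chunk_size = (flat.numel() + num_chunks - 1) // num_chunks
    padded = F.pad(flat, (0, chunk_size * num_chunks - flat.numel()))
    return list(padded.chunk(num_chunks))

grad = torch.arange(10, dtype=torch.float32)
chunks = chunk_and_pad_sketch(grad, num_chunks=4)  # 4 chunks of 3 elements
assert all(c.numel() == 3 for c in chunks)
```

Passing `self.process_group_reduce_scatter.size()` instead of `self.world_size` keeps the number of padded chunks aligned with the group that the reduce_scatter collective actually runs on.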
@@ -462,6 +462,24 @@ class TestParamInit(DistributedTest):
assert not objects_are_equal(ref_output, new_output), "new_output did not reflect change to param after init"
class TestReduceScatterProcessGroup(DistributedTest):
def test_reduce_scatter_process_group_size(self):
"""Ensure that reduce_scatter_process_group same size with the world size."""
test_fn = functools.partial(self._test_reduce_scatter_process_group_size, config={})
spawn_and_init(test_fn, world_sizes=[2])
@classmethod
def _test_reduce_scatter_process_group_size(cls, rank, group, config):
model = cls._get_model(group, config)
assert model.process_group_reduce_scatter.size() == model.world_size
@classmethod
def _get_model(cls, group, config):
with torch.no_grad(): # required for multiprocessing
model = NestedWrappedModule(group, wrapper_config=config)
return FullyShardedDataParallel(model, group, **config)
class TestSerialization(DistributedTest):
@parameterized.expand([[False, False], [True, False], [True, True], [False, True]], name_func=rename_test)
def test_pickle(self, mixed_precision, cpu_offload):
......