Unverified Commit fb4eca19 authored by tmarkstrum, committed by GitHub

[Fix][FSDP]fixed padding size of input tensor for reduce scatter (#907)



* fixed padding size of input tensor for reduce scatter, and fixed an error that assigned wrong group

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* added changelog

* fixed some commit.

* added a unit test to ensure the reduce_scatter process group size is correct in default cases, and fall back to the default process group when the reduce_scatter process group has the wrong size.

* throw an error instead of rolling back to use default process group for reduce_scatter_process_group

* Revert "throw an error instead of rolling back to use default process group for reduce_scatter_process_group"

This reverts commit eab5620da3b726ea55d3088ae4ca10d94dcdf4d9.

* added check for None to avoid unit test failure

* fixed an error to avoid the unit test failures
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
parent 0044372c
...@@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- FSDP: Enabled reduce_scatter operation to overlap with all_gather stream and computation stream in backward propagation.
This change increases FSDP's throughput. To roll back this change, please pass ProcessGroupName.default to the process_group_reduce_scatter API parameter in FSDP [#897]
### Fixed
- FSDP: Fixed an issue where the padding size of the reduce_scatter input tensor did not match the reduce_scatter process group size [#907]
## [0.4.4] - 2021-12-21
### Fixed
...
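For reference, a minimal usage sketch of the process_group_reduce_scatter parameter described above. The import path for ProcessGroupName is an assumption (it may differ between fairscale versions), and an initialized torch.distributed process group is assumed:

```python
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
# Assumption: ProcessGroupName is importable from fairscale.utils.parallel in this release.
from fairscale.utils.parallel import ProcessGroupName

# Default after [#897]: reduce_scatter runs on its own cached process group so it
# can overlap with the all_gather and computation streams during backward.
fsdp_overlap = FSDP(nn.Linear(8, 8), process_group_reduce_scatter=ProcessGroupName.reduce_scatter)

# Roll back to the previous behavior: reduce_scatter shares the default process
# group, which disables the overlap.
fsdp_default = FSDP(nn.Linear(8, 8), process_group_reduce_scatter=ProcessGroupName.default)
```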
...@@ -198,6 +198,8 @@ class FullyShardedDataParallel(nn.Module):
If it is a specific ProcessGroup, the reduce_scatter operates on this ProcessGroup, and the overlap still happens.
To disable the overlap feature, set the process group to ProcessGroupName.default. In this case, the reduce_scatter
operation uses the same process group as the default group.
If the reduce_scatter process group size is different from the default process group size, the reduce_scatter
operation rolls back to using the default process group.
reshard_after_forward (bool, Optional):
if ``True``, reshard parameters after the forward pass. This saves
memory but slows training. This is only relevant when resharding
...@@ -298,7 +300,8 @@ class FullyShardedDataParallel(nn.Module):
self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
# The type of process_group_reduce_scatter can only be either ProcessGroup or ProcessGroupName
process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter,
reshard_after_forward: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
...@@ -325,13 +328,35 @@ class FullyShardedDataParallel(nn.Module):
# the rest of operations. The overlap feature in the backward propagation is disabled.
if process_group_reduce_scatter == ProcessGroupName.default:
self.process_group_reduce_scatter = self.process_group
# If ProcessGroupName.reduce_scatter is passed in, the reduce_scatter uses a separate process group
# so that the overlap feature in the backward propagation is enabled.
elif process_group_reduce_scatter == ProcessGroupName.reduce_scatter:
self.process_group_reduce_scatter = get_process_group_cached(ProcessGroupName.reduce_scatter)
else:
# If a specific process group is passed in, the reduce_scatter will use the passed-in process group.
if isinstance(process_group_reduce_scatter, ProcessGroup):
self.process_group_reduce_scatter = process_group_reduce_scatter
else:
if not hasattr(process_group_reduce_scatter, "allgather") and hasattr(
process_group_reduce_scatter, "rank"
):
# Likely a dummy pg for unit test
self.process_group_reduce_scatter = process_group_reduce_scatter
else:
raise TypeError("unsupported type for reduce_scatter process group")
self.rank = self.process_group.rank()
self.world_size = self.process_group.size()
# In a unit test dummy environment, the process_group_reduce_scatter can be None.
if self.process_group_reduce_scatter is not None:
reduce_scatter_group_size = self.process_group_reduce_scatter.size()
# Roll back to the default process group for the reduce_scatter operation when the world size and the reduce_scatter process group size differ.
if self.world_size != reduce_scatter_group_size:
self.process_group_reduce_scatter = self.process_group
logging.warn(
"Rolled back to use the default process group for the reduce_scatter operation because the "
f"reduce_scatter process group size is {reduce_scatter_group_size}, which is different from "
f"the world size {self.world_size}. Please make sure the process_group parameter uses all "
"the available ranks for optimal performance."
)
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
...@@ -1617,7 +1642,7 @@ class FullyShardedDataParallel(nn.Module):
# unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
param.grad = None
callback_fn = functools.partial(self._post_reduction_hook, param)
grad_chunks = chunk_and_pad(grad, self.process_group_reduce_scatter.size())
self._reducer.reduce_scatter_async(
grad_chunks, group=self.process_group_reduce_scatter, callback_fn=callback_fn
)
...
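The key change in the hunk above is that the flattened gradient is now padded and chunked to the size of the reduce_scatter process group instead of self.world_size (the default group's size). Below is a standalone sketch of that padding requirement; it mirrors, but is not, fairscale's own chunk_and_pad helper:

```python
from typing import List

import torch
import torch.nn.functional as F


def pad_and_chunk(grad: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
    """Pad a flat gradient so it splits into num_chunks equally sized chunks.

    reduce_scatter needs one equal-sized input chunk per rank of the group it
    runs on, so num_chunks must match that group's size -- after this fix,
    process_group_reduce_scatter.size() rather than the default world size.
    """
    flat = grad.flatten()
    remainder = flat.numel() % num_chunks
    if remainder != 0:
        # Append zeros so the element count is a multiple of num_chunks.
        flat = F.pad(flat, (0, num_chunks - remainder))
    return list(flat.chunk(num_chunks))


# A 10-element gradient split for a 4-rank reduce_scatter group is padded to
# 12 elements so that every rank receives a 3-element chunk.
chunks = pad_and_chunk(torch.arange(10.0), num_chunks=4)
assert [c.numel() for c in chunks] == [3, 3, 3, 3]
```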
...@@ -462,6 +462,24 @@ class TestParamInit(DistributedTest):
assert not objects_are_equal(ref_output, new_output), "new_output did not reflect change to param after init"
class TestReduceScatterProcessGroup(DistributedTest):
def test_reduce_scatter_process_group_size(self):
"""Ensure that reduce_scatter_process_group same size with the world size."""
test_fn = functools.partial(self._test_reduce_scatter_process_group_size, config={})
spawn_and_init(test_fn, world_sizes=[2])
@classmethod
def _test_reduce_scatter_process_group_size(self, rank, group, config):
model = self._get_model(group, config)
assert model.process_group_reduce_scatter.size() == model.world_size
@classmethod
def _get_model(self, group, config):
with torch.no_grad(): # required for multiprocessing
model = NestedWrappedModule(group, wrapper_config=config)
return FullyShardedDataParallel(model, group, **config)
class TestSerialization(DistributedTest):
@parameterized.expand([[False, False], [True, False], [True, True], [False, True]], name_func=rename_test)
def test_pickle(self, mixed_precision, cpu_offload):
...
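Outside of the spawn_and_init harness used above, the invariant that the new test checks can be verified directly on a wrapped module; a rough sketch, assuming torch.distributed has already been initialized:

```python
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes an initialized default process group
# (e.g. via torch.distributed.init_process_group).
model = FSDP(nn.Linear(8, 8))

# With the default arguments, the cached reduce_scatter group must span the
# same ranks as the default group, i.e. match the world size.
assert model.process_group_reduce_scatter.size() == model.world_size
assert model.world_size == dist.get_world_size()
```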