Unverified Commit d3417ceb authored by four4fish and committed by GitHub

FullyShardedDataParallel: only return full state dict on rank 0 (#885)

* FullyShardedDataParallel: only return full state dict on rank 0

* Add flag and make rank 0 only optional

* Add tests

* Add docs

* address comments

* update comments

* update torch nightly version

* update torchvision number for torch nightly dependence

* add changelog

* Update CHANGELOG.md

* Update CHANGELOG.md
parent c5e471bc
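For reference, a minimal usage sketch of the new state_dict_on_rank_0_only flag (not part of this commit). It assumes a torch.distributed process group is already initialized with one GPU per rank, and uses a toy nn.Linear in place of a real model.

```python
# Hypothetical usage sketch -- not part of this commit.
# Assumes torch.distributed is already initialized (e.g. NCCL, one GPU per rank).
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Toy module standing in for a real model.
model = FSDP(torch.nn.Linear(8, 8).cuda(), state_dict_on_rank_0_only=True)

# ... training loop elided ...

state = model.state_dict()
if dist.get_rank() == 0:
    # The full state dict (and its GPU -> CPU copy) is materialized only here.
    torch.save(state, "checkpoint.pt")
else:
    # All other ranks receive an empty dict and skip the copy entirely.
    assert state == {}
```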
@@ -117,7 +117,7 @@ install_dep_pytorch_nightly: &install_dep_pytorch_nightly
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.10 && exit 0; fi
# start installing
pip install --progress-bar off --pre torch==1.11.0.dev20211101+cu111 torchvision==0.12.0.dev20211101+cu111 -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
pip install --progress-bar off --pre torch==1.11.0.dev20211231+cu111 torchvision==0.12.0.dev20211231+cu111 -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
......
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.4.5] - TBD
### Added
- FSDP: Added state_dict_on_rank_0_only flag that lets the user return the full state dict on rank 0 only and an empty dict on all other ranks, to prevent OOM [#844]
### Changed
......
@@ -279,6 +279,11 @@ class FullyShardedDataParallel(nn.Module):
The `OffloadConfig` object is used to specify the type of offload (i.e., SSD or CPU) and
other required knobs when offloading parameters from GPU. Currently the OffloadConfig
only supports specifying SSD offload as an option. Note: This is an experimental feature.
state_dict_on_rank_0_only (bool):
When set to ``True``, ``model.state_dict()`` will only return the full state dict on
rank 0 and return an empty dict on all other ranks, which allows FullyShardedDataParallel
to skip the GPU -> CPU copy on non-zero ranks altogether and prevents OOM.
Default: False
"""
def __init__(
@@ -302,6 +307,7 @@ class FullyShardedDataParallel(nn.Module):
verbose: bool = False,
cpu_offload: bool = False,
offload_config: Optional[OffloadConfig] = None,
state_dict_on_rank_0_only: bool = False,
):
init_start = time.time()
super().__init__()
@@ -324,6 +330,7 @@ class FullyShardedDataParallel(nn.Module):
self.clear_autocast_cache = clear_autocast_cache
self.force_input_to_fp32 = force_input_to_fp32
self.verbose = verbose
self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
# Experimental feature for now. Use at your own risk.
self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False
@@ -418,7 +425,7 @@ class FullyShardedDataParallel(nn.Module):
# Register hook after state_dict() to remove the "_fsdp_wrapped_module."
# prefix and before load_state_dict() to add it back.
self._register_state_dict_hook(_post_state_dict_hook)
self._register_state_dict_hook(functools.partial(_post_state_dict_hook, self.state_dict_on_rank_0_only))
self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
# Flag to indicate whether state_dict() should automatically summon the
@@ -2353,8 +2360,19 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
def _post_state_dict_hook(
module: FullyShardedDataParallel, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
state_dict_on_rank_0_only: bool,
module: FullyShardedDataParallel,
state_dict: "OrderedDict[str, torch.Tensor]",
prefix: str,
*args: Any,
) -> "OrderedDict[str, torch.Tensor]":
# When state_dict_on_rank_0_only is ``True``, ``model.state_dict()`` will only
# return the full state dict on rank 0 and return an empty dict on all other
# ranks, which allows FullyShardedDataParallel to skip the GPU -> CPU copy on
# non-zero ranks altogether and prevents OOM.
if state_dict_on_rank_0_only and dist.get_rank() != 0:
state_dict.clear()
return state_dict
# Assuming we are in a ``summon_full_params()`` context, we need to clone
# each tensor so that it does not get freed (in-place) when the context
# exits. At the same time, this hook can be called multiple times
......
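A short illustration (not from this commit, with hypothetical names) of why the hook registration above wraps _post_state_dict_hook in functools.partial: nn.Module's state-dict hooks are invoked roughly as hook(module, state_dict, prefix, ...), so the boolean flag has to be bound up front as an extra leading argument.

```python
# Toy sketch of binding a flag into a state_dict hook with functools.partial.
# toy_post_hook is a hypothetical stand-in, not the FSDP hook itself.
import functools
from collections import OrderedDict
from typing import Any

def toy_post_hook(
    rank_0_only: bool, module: Any, state_dict: "OrderedDict[str, Any]", prefix: str, *args: Any
) -> "OrderedDict[str, Any]":
    # With the flag bound, the hook still matches the expected
    # (module, state_dict, prefix, ...) calling convention.
    if rank_0_only:
        state_dict.clear()  # e.g. drop everything on non-zero ranks
    return state_dict

bound = functools.partial(toy_post_hook, True)
# bound(module, state_dict, prefix, ...) now behaves like
# toy_post_hook(True, module, state_dict, prefix, ...).
result = bound(object(), OrderedDict(a=1), "")
assert result == OrderedDict()
```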
@@ -138,6 +138,13 @@ class DistributedTest(unittest.TestCase):
shard_loss = shard_loss.cuda()
shard_state_dict = model.state_dict()
if config.get("state_dict_on_rank_0_only", False):
if torch.distributed.get_rank() != 0:
assert shard_state_dict == {}
# The rank 0 shard_state_dict case is covered in the following test.
# An early return is needed here because, with state_dict_on_rank_0_only=True,
# the asserts below would fail on ranks other than 0.
return
try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
@@ -361,6 +368,13 @@ class TestComparisonToPyTorchDDP(DistributedTest):
test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
spawn_and_init(test_fn)
@parameterized.expand([[True], [False]], name_func=rename_test)
def test_state_dict_on_rank_0_only(self, state_dict_on_rank_0_only):
config = {"state_dict_on_rank_0_only": state_dict_on_rank_0_only}
model_fn = functools.partial(TransformerWithSharedParams)
test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
spawn_and_init(test_fn)
@parameterized.expand([[{"checkpoint_act": False}], [{"checkpoint_act": True}]], name_func=rename_test)
def test_mixture_of_experts(self, moe_config):
fsdp_config = {"mixed_precision": True}
......