Unverified commit d3417ceb, authored by four4fish and committed by GitHub

FullyShardedDataParallel: only return full state dict on rank 0 (#885)

* FullyShardedDataParallel: only return full state dict on rank 0

* Add flag and make rank 0 only optional

* Add tests

* Add docs

* address comments

* update comments

* update torch nightly version

* update torchvision number for torch nightly dependence

* add changelog

* Update CHANGELOG.md

* Update CHANGELOG.md
parent c5e471bc
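
For orientation before the diff, here is a minimal usage sketch of the flag this PR introduces. It is illustrative only: it assumes a GPU is available and that torch.distributed has already been initialized on every rank; the tiny Linear module is a stand-in, not code from this commit.

    import torch.distributed as dist
    import torch.nn as nn
    from fairscale.nn import FullyShardedDataParallel as FSDP

    # Assumes dist.init_process_group(...) has already been called on every rank.
    model = FSDP(nn.Linear(8, 8).cuda(), state_dict_on_rank_0_only=True)

    # state_dict() is still called on every rank (the parameter gather is collective),
    # but only rank 0 receives the full dict.
    state = model.state_dict()
    if dist.get_rank() == 0:
        assert len(state) > 0   # full, unsharded state dict
    else:
        assert state == {}      # empty dict; the GPU -> CPU copy was skipped here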
@@ -117,7 +117,7 @@ install_dep_pytorch_nightly: &install_dep_pytorch_nightly
 # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
 if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.10 && exit 0; fi
 # start installing
-pip install --progress-bar off --pre torch==1.11.0.dev20211101+cu111 torchvision==0.12.0.dev20211101+cu111 -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
+pip install --progress-bar off --pre torch==1.11.0.dev20211231+cu111 torchvision==0.12.0.dev20211231+cu111 -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
 pip install --progress-bar off -r requirements-dev.txt
 pip install --progress-bar off -r requirements-benchmarks.txt
 python -c 'import torch; print("Torch version:", torch.__version__)'
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [0.4.5] - TBD

 ### Added
+- FSDP: Added a state_dict_on_rank_0_only flag that lets the user return the full state dict on rank 0 and an empty dict on all other ranks, to prevent OOM [#844]

 ### Changed
@@ -279,6 +279,11 @@ class FullyShardedDataParallel(nn.Module):
             The `OffloadConfig` object is used to specify the type of offload (i.e SSD, CPU) and
             other required knobs when offloading parameters from GPU. Currently the OffloadConfig
             only supports specifying SSD offload as an option. Note: This is an experimental feature.
+        state_dict_on_rank_0_only (bool):
+            When set to ``True``, ``model.state_dict()`` returns the full state dict only
+            on rank 0 and an empty dict on all other ranks, which lets FullyShardedDataParallel
+            skip the GPU -> CPU copy on non-zero ranks altogether and prevents OOM.
+            Default: False
     """

     def __init__(
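
Continuing the sketch from above the diff, the checkpointing pattern the new docstring is aimed at looks roughly like this; the checkpoint filename is a made-up placeholder, and `model` is the FSDP-wrapped module constructed with state_dict_on_rank_0_only=True in the earlier sketch:

    import torch
    import torch.distributed as dist

    state = model.state_dict()               # collective: every rank participates
    if dist.get_rank() == 0:
        torch.save(state, "checkpoint.pt")   # only rank 0 ever holds (and writes) the weights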
@@ -302,6 +307,7 @@ class FullyShardedDataParallel(nn.Module):
         verbose: bool = False,
         cpu_offload: bool = False,
         offload_config: Optional[OffloadConfig] = None,
+        state_dict_on_rank_0_only: bool = False,
     ):
         init_start = time.time()
         super().__init__()
@@ -324,6 +330,7 @@ class FullyShardedDataParallel(nn.Module):
         self.clear_autocast_cache = clear_autocast_cache
         self.force_input_to_fp32 = force_input_to_fp32
         self.verbose = verbose
+        self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
         # Experimental feature for now. Use at your own risk.
         self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False
@@ -418,7 +425,7 @@ class FullyShardedDataParallel(nn.Module):
         # Register hook after state_dict() to remove the "_fsdp_wrapped_module."
         # prefix and before load_state_dict() to add it back.
-        self._register_state_dict_hook(_post_state_dict_hook)
+        self._register_state_dict_hook(functools.partial(_post_state_dict_hook, self.state_dict_on_rank_0_only))
         self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)

         # Flag to indicate whether state_dict() should automatically summon the
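
The hunk above threads the flag into the hook by binding it as the first argument with functools.partial. Below is a standalone sketch of that binding pattern on a plain nn.Module, using the same private _register_state_dict_hook API; the demo hook name and its always-clear behaviour are made up for illustration and are not fairscale code:

    import functools
    from collections import OrderedDict
    from typing import Any

    import torch
    import torch.nn as nn

    def demo_post_state_dict_hook(
        drop_everything: bool,
        module: nn.Module,
        state_dict: "OrderedDict[str, torch.Tensor]",
        prefix: str,
        *args: Any,
    ) -> "OrderedDict[str, torch.Tensor]":
        # The partial-bound flag arrives first; the usual hook arguments
        # (module, state_dict, prefix, local_metadata) follow.
        if drop_everything:
            state_dict.clear()
        return state_dict

    m = nn.Linear(2, 2)
    m._register_state_dict_hook(functools.partial(demo_post_state_dict_hook, True))
    assert m.state_dict() == OrderedDict()   # the hook emptied the returned dict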
@@ -2353,8 +2360,19 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:


 def _post_state_dict_hook(
-    module: FullyShardedDataParallel, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
+    state_dict_on_rank_0_only: bool,
+    module: FullyShardedDataParallel,
+    state_dict: "OrderedDict[str, torch.Tensor]",
+    prefix: str,
+    *args: Any,
 ) -> "OrderedDict[str, torch.Tensor]":
+    # When state_dict_on_rank_0_only is ``True``, ``model.state_dict()`` only
+    # returns the full state dict on rank 0 and an empty dict on all other
+    # ranks, which allows FullyShardedDataParallel to skip the GPU -> CPU copy
+    # on non-zero ranks altogether and prevents OOM.
+    if state_dict_on_rank_0_only and dist.get_rank() != 0:
+        state_dict.clear()
+        return state_dict
     # Assuming we are in a ``summon_full_params()`` context, we need to clone
     # each tensor so that it does not get freed (in-place) when the context
     # exits. At the same time, this hook can be called multiple times
@@ -138,6 +138,13 @@ class DistributedTest(unittest.TestCase):
         shard_loss = shard_loss.cuda()
         shard_state_dict = model.state_dict()
+        if config.get("state_dict_on_rank_0_only", False):
+            if torch.distributed.get_rank() != 0:
+                assert shard_state_dict == {}
+                # The rank-0 shard_state_dict is checked by the assertions
+                # below; return here because, with state_dict_on_rank_0_only=True,
+                # those assertions would fail on every rank other than 0.
+                return

         try:
             torch.testing.assert_allclose(ref_loss, shard_loss)
             assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
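
A small aside on the `shard_state_dict == {}` assertion above: state_dict() returns an OrderedDict, and an empty OrderedDict compares equal to a plain empty dict, so the check is valid as written.

    from collections import OrderedDict

    assert OrderedDict() == {}   # empty OrderedDict and empty dict compare equal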
@@ -361,6 +368,13 @@ class TestComparisonToPyTorchDDP(DistributedTest):
         test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
         spawn_and_init(test_fn)

+    @parameterized.expand([[True], [False]], name_func=rename_test)
+    def test_state_dict_on_rank_0_only(self, state_dict_on_rank_0_only):
+        config = {"state_dict_on_rank_0_only": state_dict_on_rank_0_only}
+        model_fn = functools.partial(TransformerWithSharedParams)
+        test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
+        spawn_and_init(test_fn)
+
     @parameterized.expand([[{"checkpoint_act": False}], [{"checkpoint_act": True}]], name_func=rename_test)
     def test_mixture_of_experts(self, moe_config):
         fsdp_config = {"mixed_precision": True}
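
To exercise just the new parameterized test, an invocation along these lines should work; the test file path is assumed from the usual fairscale repository layout and is not shown in this diff.

    import pytest

    # -k selects both generated cases of test_state_dict_on_rank_0_only by substring match.
    pytest.main(["tests/nn/data_parallel/test_fsdp.py", "-k", "test_state_dict_on_rank_0_only"])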