Commit 4b126c7b authored by Min Xu, committed by GitHub

[feat] support a context for loading state_dict for FSDP (#1065)



* [fix]: add a context for supporting state_dict from a non-FSDP parent module

* formatting

Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 3cc7fa8d
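
For orientation, here is a minimal usage sketch of the new context manager. The single-process CPU setup, the `Parent` module, and the way the checkpoint is built are illustrative assumptions, not part of this commit; the point is only that the context suppresses the key rewrite done by FSDP's pre-load hook.

```python
# Minimal sketch (illustrative assumptions; see the lead-in above).
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import no_pre_load_state_dict_hook

# One-process "distributed" init so FSDP can create its process group.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)


class Parent(nn.Module):
    """A plain (non-FSDP) parent module that owns an FSDP-wrapped child."""

    def __init__(self) -> None:
        super().__init__()
        self.child = FSDP(nn.Linear(8, 8))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.child(x)


model = Parent()

# A checkpoint whose keys already use the wrapped-module layout: it was collected
# from the non-FSDP parent, not produced by a root FSDP instance, so its keys must
# not be rewritten again on load.
ckpt = {name: p.detach().clone() for name, p in model.named_parameters()}

# Outside the context, FSDP's pre-load hook would re-insert "_fsdp_wrapped_module."
# prefixes that these keys already carry; inside it, the keys are left untouched.
with no_pre_load_state_dict_hook():
    model.load_state_dict(ckpt)

dist.destroy_process_group()
```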
fairscale/nn/data_parallel/__init__.py
@@ -5,7 +5,13 @@
 from typing import List
-from .fully_sharded_data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState, auto_wrap_bn
+from .fully_sharded_data_parallel import (
+    FullyShardedDataParallel,
+    OffloadConfig,
+    TrainingState,
+    auto_wrap_bn,
+    no_pre_load_state_dict_hook,
+)
 from .sharded_ddp import ShardedDataParallel
 __all__: List[str] = []
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -51,7 +51,7 @@ from fairscale.internal.parallel import (
 from fairscale.internal.params import calc_grad_norm, recursive_copy_to_device
 from fairscale.internal.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.internal.state_dict import replace_by_prefix_
-from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.misc import FlattenParamsWrapper, _enable_pre_load_state_dict_hook
 from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
 from . import fsdp_optim_utils as ou
@@ -762,6 +762,7 @@ class FullyShardedDataParallel(nn.Module):
                 self.numel_padded_per_param.append(0)
                 continue
             p._is_sharded = True
+            # TODO (Min): broadcast from rank 0 to avoid each rank needing to init with the same seed?
             # Replace p.data with the relevant shard.
             orig_data = p.data
@@ -2581,9 +2582,24 @@ def _post_state_dict_hook(
     return state_dict


+@contextlib.contextmanager
+def no_pre_load_state_dict_hook() -> Generator:
+    """Disable the pre-load hook.
+
+    This is needed if we are loading a state_dict that was not produced by
+    a root FSDP instance.
+    """
+    global _enable_pre_load_state_dict_hook
+    bak = _enable_pre_load_state_dict_hook
+    _enable_pre_load_state_dict_hook = False
+    yield
+    _enable_pre_load_state_dict_hook = bak
+
+
 def _pre_load_state_dict_hook(
     state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
 ) -> None:
-    replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
+    if _enable_pre_load_state_dict_hook:
+        replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
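
Before moving on to the `fairscale.nn.misc` changes, here is a small self-contained sketch of the key rewrite that the hook above performs when it is enabled. The `replace_by_prefix` helper below is a simplified stand-in for `fairscale.internal.state_dict.replace_by_prefix_`, and the key names are made up.

```python
from collections import OrderedDict
from typing import Dict

import torch


def replace_by_prefix(state_dict: Dict[str, torch.Tensor], old: str, new: str) -> None:
    """Simplified stand-in for fairscale.internal.state_dict.replace_by_prefix_."""
    for key in list(state_dict.keys()):
        if key.startswith(old):
            state_dict[new + key[len(old):]] = state_dict.pop(key)


# Keys as they might appear for an FSDP child named "child" inside a non-FSDP parent.
sd = OrderedDict([("child.weight", torch.zeros(1)), ("child.bias", torch.zeros(1))])

# What the enabled pre-load hook does: push the keys under the wrapper module.
replace_by_prefix(sd, "child.", "child._fsdp_wrapped_module.")
print(list(sd.keys()))
# ['child._fsdp_wrapped_module.weight', 'child._fsdp_wrapped_module.bias']

# Under no_pre_load_state_dict_hook(), this rewrite is skipped and the keys stay as-is,
# which is what a state_dict not produced by a root FSDP instance expects.
```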
fairscale/nn/misc/__init__.py
@@ -9,7 +9,7 @@ from typing import List
 # in favor of fairscale.nn.checkpoint.checkpoint_wrapper.
 from fairscale.nn.checkpoint import checkpoint_wrapper
-from .flatten_params_wrapper import FlattenParamsWrapper
+from .flatten_params_wrapper import FlattenParamsWrapper, _enable_pre_load_state_dict_hook
 from .param_bucket import GradBucket, ParamBucket
 __all__: List[str] = []
fairscale/nn/misc/flatten_params_wrapper.py
@@ -49,6 +49,9 @@ from fairscale.internal.state_dict import replace_by_prefix_
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401

+# See no_pre_load_state_dict_hook context manager function in FSDP for more details.
+_enable_pre_load_state_dict_hook = True
+
 class FlatParameter(nn.Parameter):
     """A parameter that is initialized from a list of parameters and can be
@@ -543,6 +546,8 @@ def _post_state_dict_hook(
 def _pre_load_state_dict_hook(
     state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], prefix: str, *args: Any
 ) -> None:
+    if not _enable_pre_load_state_dict_hook:
+        return
     # Push everything down to ._fpw_module level.
     replace_by_prefix_(state_dict, prefix, prefix + "_fpw_module.")
     # The flat_param_* keys actually need to move one level up.
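
Finally, for readers unfamiliar with where these hooks fire, the sketch below mirrors the mechanism the commit relies on: a module-level flag, a context manager that flips it, and a load_state_dict pre-hook registered through `nn.Module._register_load_state_dict_pre_hook` (a private PyTorch API). The names are illustrative, and the `yield` is wrapped in `try/finally` so the flag is restored even if loading raises.

```python
import contextlib
from typing import Any, Generator

import torch.nn as nn

_enable_hook = True  # module-level flag, mirroring _enable_pre_load_state_dict_hook


@contextlib.contextmanager
def no_hook() -> Generator:
    """Temporarily disable the pre-load hook, restoring the previous value on exit."""
    global _enable_hook
    bak = _enable_hook
    _enable_hook = False
    try:
        yield
    finally:
        _enable_hook = bak


def _pre_load_hook(state_dict: Any, prefix: str, *args: Any) -> None:
    # When the flag is off, leave the incoming keys untouched.
    if not _enable_hook:
        return
    print(f"rewriting keys under prefix {prefix!r}")


m = nn.Linear(2, 2)
# Private PyTorch API; this is how the pre-load hooks get wired to load_state_dict.
m._register_load_state_dict_pre_hook(_pre_load_hook)

m.load_state_dict(m.state_dict())      # hook runs and prints
with no_hook():
    m.load_state_dict(m.state_dict())  # hook is skipped
```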