Unverified Commit b6dc98cf authored by Myle Ott, committed by GitHub

[fix] fix FSDP state_dict/load_state_dict for nested wrapped instances (#440)
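
Calling state_dict() / load_state_dict() on an FSDP instance that itself contains nested FSDP-wrapped children previously left "_fsdp_wrapped_module." prefixes in the keys and mishandled the nested instances. This change renames the wrapped module to self._fsdp_wrapped_module (still exposed via a module property), registers state_dict hooks that strip and restore the prefix, makes summon_full_params() recurse into nested instances, and routes local_state_dict() / load_local_state_dict() through the same code path with full-param summoning disabled. A minimal usage sketch of the round trips exercised by the new tests; the import path, model construction, and process-group setup are illustrative, and torch.distributed is assumed to be initialized:

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # A nested wrapping: an FSDP child inside an FSDP root.
    model = FSDP(nn.Sequential(nn.Linear(8, 4), FSDP(nn.Linear(4, 4)), nn.Linear(4, 8))).cuda()

    # Full (unsharded) round trip: keys match the plain nn.Module, with no wrapper prefixes.
    full_sd = model.state_dict()
    model.load_state_dict(full_sd)

    # Local (sharded) round trip: each rank handles only its own shard,
    # including the shards owned by nested FSDP children.
    local_sd = model.local_state_dict()
    model.load_local_state_dict(local_sd)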

parent 93d115c6
@@ -29,6 +29,7 @@ from fairscale.utils.containers import (
)
from fairscale.utils.parallel import chunk_and_pad, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
@@ -172,11 +173,11 @@ class FullyShardedDataParallel(nn.Module):
params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))
if self.flatten_parameters and len(params) > 0:
self.module: nn.Module = FlattenParamsWrapper(module, param_list=params)
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
del module # free original module in case it helps garbage collection
self.params = [self.module.flat_param]
self.params = [self._fsdp_wrapped_module.flat_param]
else:
self.module = module
self._fsdp_wrapped_module = module
self.params = params
# Shard module parameters in place
@@ -192,8 +193,23 @@ class FullyShardedDataParallel(nn.Module):
# pass. This will be False when inside the no_sync context manager.
self.require_backward_grad_sync: bool = True
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self.training_state = TrainingState.IDLE
# Register hook after state_dict() to remove the "_fsdp_wrapped_module."
# prefix and before load_state_dict() to add it back.
self._register_state_dict_hook(_post_state_dict_hook)
self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
# Flag to indicate whether state_dict() should automatically summon the
# full params. This defaults to True, but may be set to False if the
# user explicitly requests the local state dict via local_state_dict().
self._return_full_state_dict = True
@property
def module(self) -> nn.Module:
return self._fsdp_wrapped_module # note: may be a FlattenParamsWrapper instance
@torch.no_grad()
def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
"""Move all buffers to the specified device and dtype, recursively."""
@@ -235,6 +251,7 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
self._lazy_init()
assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
assert self.training_state == TrainingState.IDLE
@@ -374,7 +391,7 @@ class FullyShardedDataParallel(nn.Module):
self._reset_lazy_init()
# TODO (Min): figuring out how to do typing for this overloaded function.
def state_dict(self, *args, **kwargs): # type: ignore
def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tensor]": # type: ignore
"""
Returns the whole (unsharded) state of the module. Parameters are not
sharded, so the resulting state_dict can be loaded directly by the
@@ -384,16 +401,28 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
with self.summon_full_params():
if self.mixed_precision:
# Buffers dtype stays consistent with parameters.
self._all_buffers_to(dtype=torch.float32)
state_dict = self.module.state_dict(*args, **kwargs)
# We copy the state_dict since full param will be freed after
# we exit the summon_full_params() context.
for key in state_dict.keys():
state_dict[key] = state_dict[key].clone()
if self._return_full_state_dict:
if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
with self.summon_full_params():
state_dict = super().state_dict(*args, **kwargs)
else:
torch.cuda.synchronize()
self._lazy_init()
state_dict = super().state_dict(*args, **kwargs)
else:
torch.cuda.synchronize()
self._lazy_init()
if self.flatten_parameters:
assert isinstance(self.module, FlattenParamsWrapper)
state_dict = self.module.flat_state_dict(*args, **kwargs)
else:
state_dict = super().state_dict(*args, **kwargs)
if self.mixed_precision:
# In case we are in mixed precision, restore buffers back to fp16.
self._all_buffers_to(dtype=self.compute_dtype)
return state_dict
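
A brief sketch of how the branches above are reached; `fsdp_model` is an illustrative, already-constructed FSDP instance:

    # Called while idle: full params are summoned (all-gathered), the copy is
    # taken by the post-state_dict hook, and params are re-sharded on exit.
    sd = fsdp_model.state_dict()

    # Called while full params are already summoned: no extra gather happens.
    with fsdp_model.summon_full_params():
        sd = fsdp_model.state_dict()

    # Sharded path: local_state_dict() flips _return_full_state_dict to False on
    # self and every nested FSDP instance before delegating back to state_dict().
    local_sd = fsdp_model.local_state_dict()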
@@ -405,12 +434,21 @@ class FullyShardedDataParallel(nn.Module):
so the resulting state_dict can only be loaded after the Module has been
wrapped with FullyShardedDataParallel.
"""
torch.cuda.synchronize()
self._lazy_init()
if self.flatten_parameters:
return self.module.flat_state_dict(*args, **kwargs) # type: ignore
else:
return self.module.state_dict(*args, **kwargs)
with contextlib.ExitStack() as stack:
# Tell any nested FSDP instances not to auto summon full params.
for module in self.modules(): # includes self
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module._no_return_full_state_dict())
return self.state_dict(*args, **kwargs)
@contextlib.contextmanager
def _no_return_full_state_dict(self) -> Generator:
backup = self._return_full_state_dict
self._return_full_state_dict = False
try:
yield
finally:
self._return_full_state_dict = backup
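
The ExitStack pattern used by local_state_dict() above enters one _no_return_full_state_dict() context per nested FSDP instance and unwinds them all when the outer `with` block exits, restoring each flag. A standalone sketch of the same idiom; all names here are illustrative, not part of fairscale:

    import contextlib

    @contextlib.contextmanager
    def flag_off(obj):
        saved = obj.flag
        obj.flag = False
        try:
            yield
        finally:
            obj.flag = saved

    class Holder:
        flag = True

    objs = [Holder(), Holder(), Holder()]
    with contextlib.ExitStack() as stack:
        for o in objs:
            stack.enter_context(flag_off(o))
        assert not any(o.flag for o in objs)  # all flags are off inside the stack
    assert all(o.flag for o in objs)  # and restored when the stack unwinds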
def load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
@@ -421,16 +459,25 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
with self.summon_full_params():
output = self.module.load_state_dict(state_dict, strict)
return output
if self._return_full_state_dict:
with self.summon_full_params():
return self.module.load_state_dict(state_dict, strict)
else:
torch.cuda.synchronize()
self._lazy_init()
return self.module.load_state_dict(state_dict, strict)
def load_local_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
"""Load a local (sharded) state_dict."""
torch.cuda.synchronize()
return self.module.load_state_dict(state_dict, strict)
with contextlib.ExitStack() as stack:
# Tell any nested FSDP instances not to auto summon full params.
for module in self.modules(): # includes self
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module._no_return_full_state_dict())
output = self.load_state_dict(state_dict, strict)
return output
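
A per-rank checkpointing sketch built on the two methods above; `fsdp_model`, `rank`, and the file paths are illustrative:

    import torch

    # Each rank saves and reloads only its own (possibly flattened) shard;
    # no cross-rank all-gather is required.
    torch.save(fsdp_model.local_state_dict(), f"checkpoint_rank{rank}.pt")
    fsdp_model.load_local_state_dict(torch.load(f"checkpoint_rank{rank}.pt"))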
@contextlib.contextmanager
def no_sync(self) -> Generator:
@@ -457,30 +504,44 @@ class FullyShardedDataParallel(nn.Module):
m.require_backward_grad_sync = old_flag
@contextlib.contextmanager
def summon_full_params(self) -> Generator:
def summon_full_params(self, recurse: bool = True) -> Generator:
"""
A context manager to expose full params for the underlying model.
Can be useful *after* forward/backward for a model to get the params
for additional processing or checking.
A context manager to expose full params for the current FSDP instance.
Can be useful *after* forward/backward for a model to get the params for
additional processing or checking.
By default this will recursively summon all params for nested FSDP
instances; this can be disabled by setting ``recurse=False``.
This can be used on inner FSDPs.
.. note:: This can be used on inner FSDPs.
This can *not* be used within a forward or backward pass. Nor can forward
and backward be started from within this context.
.. note:: This can *not* be used within a forward or backward pass. Nor
can forward and backward be started from within this context.
"""
torch.cuda.synchronize()
self._lazy_init()
self.assert_state(TrainingState.IDLE)
# Set the state so that we assert when trying to go into
# forward/backward.
self.training_state = TrainingState.SUMMON_FULL_PARAMS
self._rebuild_full_params()
try:
yield
finally:
self._free_full_params()
self._use_fp32_param_shard()
self.training_state = TrainingState.IDLE
if recurse:
with contextlib.ExitStack() as stack:
# summon all params for any nested FSDP instances
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module.summon_full_params(recurse=False))
# yield to the caller, with full params in all nested instances
yield
# exiting from the ExitStack will re-shard params
return
else:
torch.cuda.synchronize()
self._lazy_init()
self.assert_state(TrainingState.IDLE)
# Set the state so that we assert when trying to go into
# forward/backward.
self.training_state = TrainingState.SUMMON_FULL_PARAMS
self._rebuild_full_params()
try:
yield
finally:
self._free_full_params()
self._use_fp32_param_shard()
self.training_state = TrainingState.IDLE
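
A usage sketch for the updated context manager; `fsdp_model` and the inspection code are illustrative:

    # Recursively gather full params for this instance and every nested FSDP
    # child, e.g. to inspect or post-process weights after training.
    with fsdp_model.summon_full_params():
        for name, param in fsdp_model.named_parameters():
            print(name, tuple(param.shape))

    # Gather params only for this instance; nested children stay sharded.
    with fsdp_model.summon_full_params(recurse=False):
        pass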
def _reset_lazy_init(self) -> None:
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
@@ -1011,3 +1072,23 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
return
assert data.storage().size() == 0
data.storage().resize_(size.numel())
def _post_state_dict_hook(
module: nn.Module, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
) -> "OrderedDict[str, torch.Tensor]":
if module.training_state == TrainingState.SUMMON_FULL_PARAMS:
# We copy the state_dict since the full params will be freed after
# we exit the summon_full_params() context.
for key in state_dict.keys():
state_dict[key] = state_dict[key].clone()
# Remove "_fsdp_wrapped_module." prefix
replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
return state_dict
def _pre_load_state_dict_hook(
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
) -> None:
replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
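
For reference, a small sketch of the in-place key rewriting the two hooks rely on; the dict contents are illustrative, and this assumes replace_by_prefix_ rewrites every key that starts with the old prefix:

    from fairscale.utils.state_dict import replace_by_prefix_

    sd = {"_fsdp_wrapped_module.linear.weight": 0, "_fsdp_wrapped_module.linear2.bias": 1}

    # Strip the wrapper prefix, as _post_state_dict_hook does (here prefix == ""):
    replace_by_prefix_(sd, "_fsdp_wrapped_module.", "")
    # sd == {"linear.weight": 0, "linear2.bias": 1}

    # Prepend it again before loading, as _pre_load_state_dict_hook does:
    replace_by_prefix_(sd, "", "_fsdp_wrapped_module.")
    # sd == {"_fsdp_wrapped_module.linear.weight": 0, "_fsdp_wrapped_module.linear2.bias": 1}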
@@ -415,7 +415,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
assert shape_dtype_device_match
return True
except AssertionError as e:
except (AssertionError, RuntimeError) as e:
if raise_exception:
raise e
else:
@@ -434,11 +434,9 @@ class TestSaveLoadStateDict(DistributedTest):
ddp_model.state_dict()
ddp_model.state_dict() # second call
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_state_dict_after_forward_mixed_precision(self, mixed_precision):
test_fn = functools.partial(
self._test_module_state_dict, {"flatten_parameters": False, "mixed_precision": mixed_precision}
)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_state_dict_after_forward(self, config):
test_fn = functools.partial(self._test_module_state_dict, config)
spawn_and_init(test_fn)
@parameterized.expand([[False], [True]], name_func=rename_test)
@@ -474,6 +472,62 @@ class TestSaveLoadStateDict(DistributedTest):
except Exception:
pass
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model(self, config):
if config["mixed_precision"]:
return # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
test_fn = functools.partial(self._test_nested_wrapped_model, config=config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model_local_state_dict(self, config):
if config["mixed_precision"]:
return # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
test_fn = functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config)
spawn_and_init(test_fn)
@classmethod
def _test_nested_wrapped_model(cls, rank, group, config=None):
# Get reference state dict without any nested FSDP instances.
model = NestedWrappedModule(group, None).cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}
# Create a nested FSDP-wrapped instance.
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
# Round-trip state dict save/load/save.
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
model.load_state_dict(state_dict)
state_dict = model.state_dict()
assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
for key in ref_state_dict.keys():
assert objects_are_equal(
ref_state_dict[key], state_dict[key], raise_exception=False
), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
@classmethod
def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
# Create a nested FSDP-wrapped instance.
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
# Round trip state dict save/load/save.
ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
model.load_local_state_dict(ref_state_dict)
state_dict = model.local_state_dict()
assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
for key in ref_state_dict.keys():
assert objects_are_equal(
ref_state_dict[key], state_dict[key], raise_exception=False
), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
class TestHooks(DistributedTest):
# Feel free to modify these tests as the implementation changes.
@@ -689,7 +743,10 @@ class NestedWrappedModule(nn.Module):
torch.manual_seed(0) # keep everything deterministic
self.module = nn.Sequential(
nn.Linear(8, 4), _maybe_wrap(nn.Linear(4, 16)), _maybe_wrap(nn.Linear(16, 4)), nn.Linear(4, 8),
nn.Linear(8, 4),
_maybe_wrap(nn.Sequential(_maybe_wrap(nn.Linear(4, 16)), nn.Linear(16, 16),)),
_maybe_wrap(nn.Linear(16, 4)),
nn.Linear(4, 8),
)
def get_input(self, device):