Unverified Commit a825348d authored by anj-s, committed by GitHub


[FSDP][feature] Support returning the original parameter names after a model has been wrapped with FSDP (#755)

* checkpoint work

* fix lint issues

* remove debug statement

* remove print

* fix lint errors

* fix lint errors

* fix lint errors

* add comments and fix lint errors

* modified comments and tests
parent 31d600cc
@@ -18,6 +18,7 @@ from typing import (
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    Mapping,
    NamedTuple,
@@ -33,8 +34,8 @@ from torch.autograd import Variable
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
@@ -659,6 +660,29 @@ class FullyShardedDataParallel(nn.Module):
        del self.orig_sizes
        self._reset_lazy_init()

    def named_parameters(self, *args: Any, **kwargs: Any) -> Iterator[Tuple[str, Parameter]]:
        """Returns an iterator over the module parameters, yielding both the
        name of each parameter and the parameter itself.

        With FSDP, the `named_parameters` function implemented in `nn.Module`
        cannot return the original names alongside the params when flattened
        parameters are used, unless it is called under a `summon_full_params`
        context. If you want the full params to be returned, call this function
        under a `summon_full_params` context, whether you are using flattened
        or original params.
        """
        named_param = super().named_parameters(*args, **kwargs)
        for name, param in named_param:
            if (
                hasattr(self, "flatten_parameters")
                and self.flatten_parameters
                and hasattr(self, "training_state")
                and self.training_state != TrainingState.SUMMON_FULL_PARAMS
            ):
                yield name, param
            else:
                yield _clean_path(name), param

    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls in case the module is a nn.Sequential."""
        return self.module.__getitem__(key)
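
For readers skimming this diff, here is a hedged usage sketch of the behavior the docstring above describes. The single-process gloo setup, the Linear module, and the exact printed names are illustrative assumptions; only `FullyShardedDataParallel.named_parameters` and `summon_full_params` come from this change.

# Illustrative sketch only: assumes fairscale with this patch is installed and
# that a single-process gloo group is enough to construct FSDP on CPU.
import os
import torch.distributed as dist
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = FSDP(nn.Linear(8, 8), flatten_parameters=True)

# Outside summon_full_params, a flattened model still reports the flat
# parameter (a single name containing "flat_param_0").
print([name for name, _ in model.named_parameters()])

# Under summon_full_params, names are cleaned back to the original module
# paths (e.g. "weight" and "bias" for the wrapped Linear).
with model.summon_full_params():
    print([name for name, _ in model.named_parameters()])

dist.destroy_process_group()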
@@ -2001,7 +2025,7 @@ def _post_state_dict_hook(
    # each tensor so that it does not get freed (in-place) when the context
    # exits. At the same time, this hook can be called multiple times
    # recursively, so we need to make sure that we only clone each tensor at
    # mostonce. Thus we add an attribute on the tensor called "_has_been_cloned"
    # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
    # which keeps track of tensors that are no longer at risk of being freed.
    for key in state_dict.keys():
        if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
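
As a standalone illustration of the clone-at-most-once bookkeeping described in the comment above, here is a small hedged sketch; the `clone_once` helper, its state_dict contents, and the prefix are made-up placeholders, not FSDP internals.

import torch

def clone_once(state_dict, prefix):
    # Clone each tensor under `prefix` at most once, even if called repeatedly,
    # by tagging already-cloned tensors with a "_has_been_cloned" attribute.
    for key in state_dict.keys():
        if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
            continue
        state_dict[key] = state_dict[key].clone()
        state_dict[key]._has_been_cloned = True

sd = {"layer.weight": torch.randn(4, 4), "other.bias": torch.randn(4)}
clone_once(sd, "layer.")
clone_once(sd, "layer.")  # second call is a no-op for the already-cloned tensor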
@@ -533,6 +533,41 @@ class TestNoGrad(DistributedTest):
        assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)


class TestModuleProperties(DistributedTest):
    @parameterized.expand([[{"flatten_parameters": False}], [{"flatten_parameters": True}]], name_func=rename_test)
    def test_named_parameters(self, config):
        test_fn = functools.partial(self._test_named_params, config=config)
        spawn_and_init(test_fn)

    @classmethod
    def _test_named_params(self, rank, group, config):
        # Get the named parameters before wrapping. Materialize the generator
        # so it can be iterated more than once below.
        before_wrap_model = TransformerWithSharedParams(group)
        before_wrap_params = list(before_wrap_model.named_parameters())

        # Train the model for 1 step.
        model = self.get_wrapped_model(group, cuda_first=False, config=config)
        self._train_for_several_steps(model, 1, autocast=False)

        # Get the named parameters after wrapping to compare.
        after_wrap_params = model.named_parameters()

        if not config["flatten_parameters"]:
            for before_nm, after_nm in zip(before_wrap_params, after_wrap_params):
                assert before_nm[0] == after_nm[0]
        else:
            named_params_flat = [p for p in after_wrap_params][0][0]
            assert "flat_param_0" in named_params_flat

        # Compare name and size under the `summon_full_params` context.
        with model.summon_full_params():
            after_wrap_params = model.named_parameters()

            for before_nm, after_nm_original in zip(before_wrap_params, after_wrap_params):
                assert before_nm[0] == after_nm_original[0]
                torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].cpu().shape)


class TransformerWithSharedParams(nn.Module):
    def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
        super().__init__()