Unverified commit a825348d authored by anj-s, committed by GitHub

[FSDP][feature] Support returning the original parameter names after a model has been wrapped with FSDP (#755)

* checkpoint work

* fix lint issues

* remove debug statement

* remove print

* fix lint errors

* fix lint errors

* fix lint errors

* add comments and fix lint errors

* modified comments and tests
parent 31d600cc
@@ -18,6 +18,7 @@ from typing import (
     Callable,
     Dict,
     Generator,
+    Iterator,
     List,
     Mapping,
     NamedTuple,
@@ -33,8 +34,8 @@ from torch.autograd import Variable
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 import torch.nn as nn
-from torch.nn import Parameter
 import torch.nn.functional as F
+from torch.nn.parameter import Parameter
 
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
@@ -659,6 +660,29 @@ class FullyShardedDataParallel(nn.Module):
         del self.orig_sizes
         self._reset_lazy_init()
 
+    def named_parameters(self, *args: Any, **kwargs: Any) -> Iterator[Tuple[str, Parameter]]:
+        """Returns an iterator over the module parameters, yielding both the name of
+        the parameter and the parameter itself.
+
+        With FSDP, the ``named_parameters`` implementation inherited from ``nn.Module``
+        cannot return the original names when flattened parameters are used, unless it
+        is called under a ``summon_full_params`` context. If you want the full
+        (unsharded) param to be returned, call this function under a
+        ``summon_full_params`` context, whether flattened or original params are used.
+        """
+        named_param = super().named_parameters(*args, **kwargs)
+        for name, param in named_param:
+            if (
+                hasattr(self, "flatten_parameters")
+                and self.flatten_parameters
+                and hasattr(self, "training_state")
+                and self.training_state != TrainingState.SUMMON_FULL_PARAMS
+            ):
+                yield name, param
+            else:
+                yield _clean_path(name), param
+
     def __getitem__(self, key: int) -> Any:
         """Forward indexing calls in case the module is a nn.Sequential."""
         return self.module.__getitem__(key)
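A minimal usage sketch of the behavior the new override documents, assuming a single-process gloo group just so FSDP can be constructed; the toy Sequential module and the exact flattened name are illustrative, not taken from this diff:

    import os
    import torch.distributed as dist
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Single-process process group, only so that FSDP can be instantiated here.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    model = FSDP(nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2)), flatten_parameters=True)

    # With flatten_parameters=True and outside summon_full_params, the yielded
    # names refer to the flattened parameter (containing e.g. "flat_param_0").
    print([name for name, _ in model.named_parameters()])

    # Under summon_full_params, the cleaned original names (e.g. "0.weight",
    # "1.bias") are yielded, and the params are the full, unsharded tensors.
    with model.summon_full_params():
        print([name for name, _ in model.named_parameters()])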
@@ -2001,7 +2025,7 @@ def _post_state_dict_hook(
     # each tensor so that it does not get freed (in-place) when the context
     # exits. At the same time, this hook can be called multiple times
     # recursively, so we need to make sure that we only clone each tensor at
-    # mostonce. Thus we add an attribute on the tensor called "_has_been_cloned"
+    # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
     # which keeps track of tensors that are no longer at risk of being freed.
     for key in state_dict.keys():
         if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
...
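The comment in the hunk above describes a clone-at-most-once guard. As a standalone illustration of that pattern only (the dictionary, prefix, and surrounding hook machinery are hypothetical stand-ins, not the real hook):

    import torch

    # Hypothetical stand-ins for the hook's state_dict and prefix arguments.
    state_dict = {"layer.weight": torch.ones(2, 2), "other.bias": torch.zeros(2)}
    prefix = "layer."

    for key in state_dict.keys():
        # Skip keys outside this module's prefix and tensors already cloned by an
        # earlier (recursive) invocation.
        if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
            continue
        # Clone so the tensor is not freed in place when the enclosing context
        # exits, then mark it so it is cloned at most once.
        state_dict[key] = state_dict[key].clone()
        state_dict[key]._has_been_cloned = True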
@@ -533,6 +533,41 @@ class TestNoGrad(DistributedTest):
         assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)
 
 
+class TestModuleProperties(DistributedTest):
+    @parameterized.expand([[{"flatten_parameters": False}], [{"flatten_parameters": True}]], name_func=rename_test)
+    def test_named_parameters(self, config):
+        test_fn = functools.partial(self._test_named_params, config=config)
+        spawn_and_init(test_fn)
+
+    @classmethod
+    def _test_named_params(self, rank, group, config):
+        # Get the named parameters before wrapping.
+        before_wrap_model = TransformerWithSharedParams(group)
+        before_wrap_params = before_wrap_model.named_parameters()
+
+        # Train the model for 1 step.
+        model = self.get_wrapped_model(group, cuda_first=False, config=config)
+        self._train_for_several_steps(model, 1, autocast=False)
+
+        # Get the named parameters after wrapping to compare.
+        after_wrap_params = model.named_parameters()
+
+        if not config["flatten_parameters"]:
+            for before_nm, after_nm in zip(before_wrap_params, after_wrap_params):
+                assert before_nm[0] == after_nm[0]
+        else:
+            named_params_flat = [p for p in after_wrap_params][0][0]
+            assert "flat_param_0" in named_params_flat
+
+        # Compare name and size under the `summon_full_params` context.
+        with model.summon_full_params():
+            after_wrap_params = model.named_parameters()
+            for before_nm, after_nm_original in zip(before_wrap_params, after_wrap_params):
+                assert before_nm[0] == after_nm_original[0]
+                torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].cpu().shape)
+
+
 class TransformerWithSharedParams(nn.Module):
     def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
         super().__init__()
...
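The `_clean_path` helper called by the new `named_parameters` override is not part of this diff. A rough sketch of what such a name-cleaning helper typically does, dropping FSDP wrapper components from dotted parameter paths, is below; the wrapper component names are assumptions, not taken from this commit:

    # Hypothetical reimplementation for illustration only; the real _clean_path lives
    # in fairscale and may differ. Assumed wrapper components: "_fsdp_wrapped_module"
    # (inserted by FullyShardedDataParallel) and "_fpw_module" (inserted by
    # FlattenParamsWrapper).
    def clean_path_sketch(path: str) -> str:
        assumed_wrapper_names = {"_fsdp_wrapped_module", "_fpw_module"}
        return ".".join(part for part in path.split(".") if part not in assumed_wrapper_names)

    print(clean_path_sketch("_fsdp_wrapped_module._fpw_module.0.weight"))  # -> "0.weight"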