Unverified Commit 506d6209 authored by Myle Ott, committed by GitHub

[fix] Fix nested FlattenParamsWrapper state_dict/load_state_dict (#434)

parent 9163e381
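
This change makes state_dict() and load_state_dict() round-trip correctly when FlattenParamsWrapper instances are nested inside one another. A minimal sketch of the behaviour exercised by the updated tests further down (the import path and the toy module shapes are assumptions for illustration, not part of this commit):

import torch.nn as nn
from fairscale.nn.misc import FlattenParamsWrapper  # assumed import path

def make_module():
    # nested wrappers, similar to _get_nested_flat_module in the tests below
    return nn.Sequential(
        FlattenParamsWrapper(nn.Sequential(nn.Linear(4, 8), FlattenParamsWrapper(nn.Linear(8, 8)))),
        FlattenParamsWrapper(nn.Linear(8, 4)),
    )

module = make_module()
ref_state_dict = module.state_dict()  # keys of the original (not re-wrapped) module

flat_module = FlattenParamsWrapper(make_module())
flat_state_dict = flat_module.state_dict()  # the new hooks strip the "_fpw_module." prefixes

assert ref_state_dict.keys() == flat_state_dict.keys()
flat_module.load_state_dict(ref_state_dict)  # a non-flat dict now triggers unflatten_params()
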
# Copyright (c) Tongzhou Wang
# Licensed under the MIT License.

-from contextlib import contextmanager
+from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union

import torch
from torch import Tensor
import torch.nn as nn

+from fairscale.utils.state_dict import replace_by_prefix_
+
if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: F401
@@ -32,7 +34,8 @@ class FlattenParamsWrapper(nn.Module):
    def __init__(self, module: nn.Module, param_list: Optional[List[nn.Parameter]] = None):
        super().__init__()
-        self.module = module
+        self._fpw_module = module
+        self.is_flattened = False
        if param_list is not None:
            assert len(param_list) > 0, "param_list can't be empty"
@@ -53,7 +56,24 @@ class FlattenParamsWrapper(nn.Module):
        # register the views as plain attributes
        self._unflatten_params_as_views()

+        # Register hook to be called after state_dict() to remove the
+        # "_fpw_module." prefix and before load_state_dict() to add it back.
+        self._register_state_dict_hook(_post_state_dict_hook)
+        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
+
+        # Flag to indicate whether state_dict() should automatically unflatten
+        # params. This defaults to True, but may be set to False if the user
+        # explicitly requests a flat state dict via flat_state_dict().
+        self._auto_unflatten_state_dict = True
+
+    @property
+    def module(self) -> nn.Module:
+        return self._fpw_module
+
    def _flatten_params(self) -> None:
+        assert not self.is_flattened
+        self.is_flattened = True
+
        param_infos = []
        shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
        shared_param_infos = []
@@ -98,6 +118,9 @@ class FlattenParamsWrapper(nn.Module):
        return (t.view(s) for (t, s) in zip(self.flat_param.split(self._param_numels), self._param_shapes))

    def _unflatten_params(self) -> None:
+        assert self.is_flattened
+        self.is_flattened = False
+
        ps = self._get_param_views()
        for (m, n), p in zip(self._param_infos, ps):
            if hasattr(m, n):
@@ -110,6 +133,7 @@ class FlattenParamsWrapper(nn.Module):
        del self.flat_param

    def _unflatten_params_as_views(self) -> None:
+        assert self.is_flattened
        ps = self._get_param_views()
        for (m, n), p in zip(self._param_infos, ps):
            setattr(m, n, p)  # This will set as plain attr
@@ -117,11 +141,30 @@ class FlattenParamsWrapper(nn.Module):
            setattr(m, n, getattr(shared_m, shared_n))

    @contextmanager
-    def unflatten_params(self) -> Generator:
-        self._unflatten_params()
-        yield
-        self._flatten_params()
-        self._unflatten_params_as_views()
+    def unflatten_params(self, recurse: bool = True) -> Generator:
+        """
+        Unflatten params (optionally recursively on all nested instances).
+        If the current instance is already unflattened, then it will remain
+        unflattened after the context manager exits.
+        """
+        if recurse:
+            with ExitStack() as stack:
+                # unflatten any nested FlattenParamsWrapper instances
+                for module in self.modules():
+                    if isinstance(module, FlattenParamsWrapper):
+                        stack.enter_context(module.unflatten_params(recurse=False))
+                # yield to the caller, with unflattened params in all nested instances
+                yield
+            # exiting from the ExitStack will re-flatten params
+            return
+        else:
+            orig_flattened = self.is_flattened
+            if self.is_flattened:
+                self._unflatten_params()
+            yield
+            if orig_flattened:
+                self._flatten_params()
+                self._unflatten_params_as_views()

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to wrapped module."""
@@ -131,23 +174,62 @@ class FlattenParamsWrapper(nn.Module):
        return getattr(self.module, name)  # fallback to wrapped module

    def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, Tensor]":  # type: ignore
-        """Return an unflattened state_dict."""
-        with self.unflatten_params():
-            return self.module.state_dict(*args, **kwargs)
+        """Return the wrapped module's state_dict (unflattened)."""
+        if self.is_flattened and self._auto_unflatten_state_dict:
+            with self.unflatten_params(recurse=False):
+                return super().state_dict(*args, **kwargs)
+        else:
+            return super().state_dict(*args, **kwargs)

    def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Return the flattened state_dict."""
-        return super().state_dict(*args, **kwargs)
+        assert self.is_flattened
+        with ExitStack() as stack:
+            # tell any nested FlattenParamsWrapper instances not to auto unflatten
+            for module in self.modules():  # includes self
+                if isinstance(module, FlattenParamsWrapper):
+                    stack.enter_context(module._no_auto_unflatten_state_dict())
+            state_dict = self.state_dict(*args, **kwargs)
+        return state_dict
+
+    @contextmanager
+    def _no_auto_unflatten_state_dict(self) -> Generator:
+        backup = self._auto_unflatten_state_dict
+        self._auto_unflatten_state_dict = False
+        yield
+        self._auto_unflatten_state_dict = backup

    def load_state_dict(
        self, state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], strict: bool = True
    ) -> NamedTuple:
-        if "flat_param" in state_dict:
-            return super().load_state_dict(state_dict, strict=strict)
+        """
+        Load a state dict. If necessary, ``unflatten_params`` will be called to
+        match the input state_dict.
+        """
+        # unflatten the module automatically if the state_dict is non-flat
+        if self.is_flattened and "flat_param" not in state_dict:
+            with self.unflatten_params(recurse=True):
+                return super().load_state_dict(state_dict, strict)
        else:
-            with self.unflatten_params():
-                return self.module.load_state_dict(state_dict, strict)
+            return super().load_state_dict(state_dict, strict)

    def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
        self._unflatten_params_as_views()
        return self.module(*inputs, **kwinputs)
+
+
+def _post_state_dict_hook(
+    module: nn.Module, state_dict: "OrderedDict[str, Tensor]", prefix: str, *args: Any
+) -> "OrderedDict[str, Tensor]":
+    replace_by_prefix_(state_dict, prefix + "_fpw_module.", prefix)
+    return state_dict
+
+
+def _pre_load_state_dict_hook(
+    state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], prefix: str, *args: Any
+) -> None:
+    replace_by_prefix_(state_dict, prefix, prefix + "_fpw_module.")
+    # flat_param actually needs to move one level up though
+    flat_param_key = prefix + "_fpw_module.flat_param"
+    if flat_param_key in state_dict:
+        replace_by_prefix_(state_dict, flat_param_key, prefix + "flat_param")
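
Both hooks above are just key renames built on replace_by_prefix_ from the new fairscale.utils.state_dict helpers below. A short, purely illustrative walk-through of the pre-load rewrite for a wrapper mounted at a hypothetical "layer1." prefix (the prefix and the zero tensor are made up):

import torch
from fairscale.utils.state_dict import replace_by_prefix_

state_dict = {"layer1.flat_param": torch.zeros(10)}
prefix = "layer1."

# Step 1 of _pre_load_state_dict_hook: push every key under the wrapped module.
replace_by_prefix_(state_dict, prefix, prefix + "_fpw_module.")
assert list(state_dict) == ["layer1._fpw_module.flat_param"]

# Step 2: flat_param belongs to the wrapper itself, so move it back up one level.
replace_by_prefix_(state_dict, prefix + "_fpw_module.flat_param", prefix + "flat_param")
assert list(state_dict) == ["layer1.flat_param"]
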
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Useful functions for manipulating state_dicts."""
from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union
from torch import Tensor, nn
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401


def find_module_instances(module: nn.Module, search_class: Type[nn.Module]) -> List[Tuple[str, nn.Module]]:
    """
    Find all occurrences of a given search_class among the given Module's
    children and return the corresponding paths in the same format as
    state_dicts.

    Usage::

        net = nn.Sequential(
            nn.Linear(1, 1),
            nn.ModuleDict({"ln": nn.LayerNorm(1), "linear": nn.Linear(1, 1)}),
            nn.LayerNorm(1)
        )

        >>> find_module_instances(net, nn.LayerNorm)
        [('1.ln.', LayerNorm((1,), eps=1e-05, elementwise_affine=True)), ('2.', LayerNorm((1,), eps=1e-05, elementwise_affine=True))]
        >>> find_module_instances(net, nn.Dropout)
        []
        >>> find_module_instances(net, nn.Sequential)
        [('', Sequential(
          (0): Linear(in_features=1, out_features=1, bias=True)
          (1): ModuleDict(
            (ln): LayerNorm((1,), eps=1e-05, elementwise_affine=True)
            (linear): Linear(in_features=1, out_features=1, bias=True)
          )
          (2): LayerNorm((1,), eps=1e-05, elementwise_affine=True)
        ))]
    """
    paths = []

    def add_paths_(module: nn.Module, prefix: str = "") -> None:
        if isinstance(module, search_class):
            paths.append((prefix, module))
        for name, child in module.named_children():
            add_paths_(child, prefix + name + ".")

    add_paths_(module)
    return paths


def replace_by_prefix_(
    state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], old_prefix: str, new_prefix: str
) -> None:
    """
    Replace all keys that match a given old_prefix with a new_prefix (in-place).

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        replace_by_prefix_(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
    """
    if old_prefix == new_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")
    for key in list(state_dict.keys()):
        if not key.startswith(old_prefix):
            continue
        new_key = new_prefix + key[len(old_prefix) :]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]
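
The wrapper code above only imports replace_by_prefix_, but find_module_instances returns prefixes in the same format as state_dict keys, so the two helpers compose naturally. A small illustrative sketch (the toy nn.Sequential and the "backbone." renaming are made up for this example):

from torch import nn
from fairscale.utils.state_dict import find_module_instances, replace_by_prefix_

net = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.Linear(2, 2)))
state_dict = net.state_dict()  # keys: "0.weight", "0.bias", "1.0.weight", "1.0.bias"

# Move every nn.Linear subtree under a hypothetical "backbone." namespace.
for prefix, _ in find_module_instances(net, nn.Linear):
    replace_by_prefix_(state_dict, prefix, "backbone." + prefix)

assert sorted(state_dict) == ["backbone.0.bias", "backbone.0.weight", "backbone.1.0.bias", "backbone.1.0.weight"]
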
@@ -28,6 +28,7 @@ from . import cuda as cuda
from . import optim as optim
from . import nn as nn
from . import testing as testing
+from . import utils as utils

#MODIFIED BY TORCHGPIPE
from . import backends
@@ -2,6 +2,7 @@
from typing import Optional, Tuple, Union, Dict, Any
import ctypes

+from . import amp
from .. import device as _device

def is_available() -> bool: ...
@@ -409,7 +409,8 @@ class TestLocalStateDict(DistributedTest):
        # Assert that parameters were updated since before training
        unchanged = []
-        buffers = {name for name, _ in model.module.named_buffers()}
+        unwrapped_model = model.module.module if config["flatten_parameters"] else model.module
+        buffers = {name for name, _ in unwrapped_model.named_buffers()}
        for k in state_1:
            if (state_before_training[k] == state_after_training[k]).all() and (k not in buffers):
                unchanged.append(k)
@@ -16,12 +16,26 @@ from fairscale.utils.testing import objects_are_equal
class TestFlattenParams(unittest.TestCase):
+    def _get_module_init_fns(self):
+        return [
+            self._get_shared_params_transformer,
+            self._get_nested_flat_module,
+        ]
+
    def _get_transformer(self, seed=0):
        torch.manual_seed(seed)  # keep everything deterministic
        module = torch.nn.Transformer(
            d_model=32, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=128, dropout=0.1,
        )
        module.register_buffer("dummy_buffer", torch.tensor(1.0))
+
+        def get_input(device, dtype):
+            torch.manual_seed(1)  # keep everything deterministic
+            src = torch.rand(20, 8, 32).to(device=device, dtype=dtype)  # T x B x C
+            tgt = torch.rand(10, 8, 32).to(device=device, dtype=dtype)  # T x B x C
+            return (src, tgt)
+
+        module.get_input = get_input
        return module

    def _get_shared_params_transformer(self, seed=0):

@@ -32,13 +46,27 @@ class TestFlattenParams(unittest.TestCase):
        dec_layer.linear2.weight = enc_layer.linear2.weight
        return module

+    def _get_nested_flat_module(self, seed=0):
+        module = torch.nn.Sequential(
+            FlattenParamsWrapper(
+                torch.nn.Sequential(torch.nn.Linear(4, 8), FlattenParamsWrapper(torch.nn.Linear(8, 8)))
+            ),
+            FlattenParamsWrapper(torch.nn.Sequential(FlattenParamsWrapper(torch.nn.Linear(8, 16)))),
+            FlattenParamsWrapper(torch.nn.Linear(16, 4)),
+        )
+
+        def get_input(device, dtype):
+            torch.manual_seed(1)  # keep everything deterministic
+            return (torch.rand(8, 4).to(device=device, dtype=dtype),)
+
+        module.get_input = get_input
+        return module
+
    def _get_output(self, module):
-        torch.manual_seed(1)  # keep everything deterministic
        device = next(module.parameters()).device
        dtype = next(module.parameters()).dtype
-        src = torch.rand(20, 8, 32).to(device=device, dtype=dtype)  # T x B x C
-        tgt = torch.rand(10, 8, 32).to(device=device, dtype=dtype)  # T x B x C
-        return module(src, tgt)
+        input = module.get_input(device, dtype)
+        return module(*input)

    def _get_pnorm_after_step(self, module):
        optim = torch.optim.SGD(module.parameters(), lr=0.01)
@@ -120,39 +148,53 @@ class TestFlattenParams(unittest.TestCase):
        torch.testing.assert_allclose(ref_pnorm_after_step, flat_pnorm_after_step)

    def test_state_dict_equality(self):
-        module = self._get_shared_params_transformer()
-        ref_state_dict = module.state_dict()
-
-        flat_module = FlattenParamsWrapper(module)
-        flat_state_dict = flat_module.state_dict()
-
-        assert objects_are_equal(ref_state_dict, flat_state_dict)
+        """Test that unflattened state dict matches original (unwrapped) one."""
+        modules_to_test = [init_fn() for init_fn in self._get_module_init_fns()]
+        for module in modules_to_test:
+            ref_state_dict = module.state_dict()
+
+            flat_module = FlattenParamsWrapper(module)
+            flat_state_dict = flat_module.state_dict()
+
+            assert (
+                ref_state_dict.keys() == flat_state_dict.keys()
+            ), f"{ref_state_dict.keys()} != {flat_state_dict.keys()}"
+            assert objects_are_equal(ref_state_dict, flat_state_dict), f"{ref_state_dict} != {flat_state_dict}"

    def test_load_state_dict(self):
-        module = self._get_shared_params_transformer()
-        ref_state_dict = module.state_dict()
-        ref_output = self._get_output(module)
-
-        module = self._get_shared_params_transformer(seed=1234)
-        flat_module = FlattenParamsWrapper(module)
-        flat_module.load_state_dict(ref_state_dict)
-        flat_output = self._get_output(flat_module)
-        assert objects_are_equal(ref_output, flat_output)
+        """Test that original (unwrapped) state_dict can be loaded in wrapped module."""
+        for module_init_fn in self._get_module_init_fns():
+            module = module_init_fn()
+            ref_state_dict = module.state_dict()
+            ref_output = self._get_output(module)
+
+            module = module_init_fn(seed=1234)
+            flat_module = FlattenParamsWrapper(module)
+
+            # This should work without the unflatten_params context manager
+            flat_module.load_state_dict(ref_state_dict)
+            flat_output = self._get_output(flat_module)
+            assert objects_are_equal(ref_output, flat_output)
+
+            # And it should work with the context manager too
+            with flat_module.unflatten_params():
+                flat_module.load_state_dict(ref_state_dict)
+            flat_output = self._get_output(flat_module)
+            assert objects_are_equal(ref_output, flat_output)

    def test_flat_state_dict(self):
-        flat_module = self._get_shared_params_transformer()
-        flat_module = FlattenParamsWrapper(flat_module)
-        ref_output = self._get_output(flat_module)
-
-        flat_state_dict = flat_module.flat_state_dict()
-
-        new_module = self._get_shared_params_transformer(seed=1234)
-        new_module = FlattenParamsWrapper(new_module)
-        new_module.load_state_dict(flat_state_dict)
-        new_output = self._get_output(new_module)
-
-        assert objects_are_equal(ref_output, new_output)
+        """Test that flat state dict can be reloaded and produces the same results."""
+        for module_init_fn in self._get_module_init_fns():
+            flat_module = FlattenParamsWrapper(module_init_fn())
+            ref_output = self._get_output(flat_module)
+
+            flat_state_dict = flat_module.flat_state_dict()
+
+            new_module = FlattenParamsWrapper(module_init_fn(seed=1234))
+            new_module.load_state_dict(flat_state_dict)
+            new_output = self._get_output(new_module)
+
+            assert objects_are_equal(ref_output, new_output)

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test utility classes from state_dict.py. """

import torch
from torch import nn

from fairscale.utils.state_dict import find_module_instances, replace_by_prefix_


def test_find_module_instances():
    net = nn.Sequential(
        nn.Linear(1, 1), nn.ModuleDict({"ln": nn.LayerNorm(1), "linear": nn.Linear(1, 1)}), nn.LayerNorm(1)
    )
    assert find_module_instances(net, nn.LayerNorm) == [("1.ln.", net[1]["ln"]), ("2.", net[2])]
    assert find_module_instances(net, nn.Linear) == [("0.", net[0]), ("1.linear.", net[1]["linear"])]
    assert find_module_instances(net, nn.Dropout) == []
    assert find_module_instances(net, nn.Sequential) == [("", net)]


def test_replace_by_prefix():
    state_dict = {"layer.a": torch.tensor(1), "abc.layer.def": torch.tensor(2), "layer.b": torch.tensor(3)}
    replace_by_prefix_(state_dict, "layer.", "module.layer.")
    assert state_dict == {
        "module.layer.a": torch.tensor(1),
        "abc.layer.def": torch.tensor(2),
        "module.layer.b": torch.tensor(3),
    }