"docs/vscode:/vscode.git/clone" did not exist on "7d7ae0a1b0df87ce8ac123cd8b97ade6b15bac2f"
Unverified commit 35fdf537 authored by Myle Ott, committed by GitHub

Add FlattenParamsWrapper (#317)

parent 81841734
@@ -472,3 +472,27 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
------------- LICENSE FOR FlattenParamsWrapper --------------
MIT License
Copyright (c) 2018 Tongzhou Wang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -174,6 +174,8 @@ fairscale.nn.model_parallel is forked from [Megatron-LM](https://github.com/NVID
fairscale.optim.adascale is forked from [AdaptDL](https://github.com/petuum/adaptdl), Copyright 2020, Petuum, Inc., licensed under [Apache License](http://www.apache.org/licenses/LICENSE-2.0).
fairscale.nn.misc.flatten_params_wrapper is forked from [PyTorch-Reparam-Module](https://github.com/SsnL/PyTorch-Reparam-Module), Copyright 2018, Tongzhou Wang, licensed under [MIT License](https://github.com/SsnL/PyTorch-Reparam-Module/blob/master/LICENSE).
## References
Here is a list of all authors on relevant research papers this work is based on:
...
@@ -4,7 +4,15 @@
# LICENSE file in the root directory of this source tree.
from .data_parallel import ShardedDataParallel
from .misc import FlattenParamsWrapper
from .moe import MOELayer, Top2Gate
from .pipe import LazyModule, Pipe, PipeRPCWrapper
__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule", "ShardedDataParallel"]
__all__ = [
"FlattenParamsWrapper",
"LazyModule",
"Pipe",
"PipeRPCWrapper",
"ShardedDataParallel",
"Top2Gate",
]
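
With FlattenParamsWrapper exported from the package's top-level __init__, it can be imported directly from fairscale.nn. A minimal sketch, assuming fairscale is installed at this commit; the wrapped nn.Linear is an arbitrary example module, not part of this diff:

import torch.nn as nn
from fairscale.nn import FlattenParamsWrapper

# Wrap an arbitrary example module (not part of this change).
flat = FlattenParamsWrapper(nn.Linear(4, 4))
print(flat.flat_param.numel())  # 20: the 4x4 weight plus the 4-element bias, flattened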
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from .flatten_params_wrapper import FlattenParamsWrapper
# Copyright (c) Tongzhou Wang
# Licensed under the MIT License.
from collections import namedtuple
from contextlib import contextmanager
from typing import List, Optional
import torch
import torch.nn as nn
class FlattenParamsWrapper(nn.Module):
    """
    A wrapper for transparently flattening a Module's parameters.

    Compared to the original implementation [1], this version:
    - removes tracing
    - supports shared parameters
    - handles state_dict/load_state_dict transparently
    - is renamed to FlattenParamsWrapper

    [1] https://github.com/SsnL/PyTorch-Reparam-Module

    Args:
        module (nn.Module): module to wrap
        param_list (Optional[List[nn.Parameter]]): only flatten parameters
            appearing in the given list (default: flatten all parameters)
    """

    def __init__(
        self, module: nn.Module, param_list: Optional[List[nn.Parameter]] = None
    ):
        super().__init__()
        self.module = module

        if param_list is not None:
            assert len(param_list) > 0, "param_list can't be empty"
        else:
            param_list = module.parameters()
        param_list = set(param_list)

        # convert from list of Parameters to set of (Module, name) tuples, which
        # will survive in case the Parameter instances are reset
        self._param_list = set()
        for m in self.modules():
            for n, p in m.named_parameters(recurse=False):
                if p in param_list:
                    self._param_list.add((m, n))

        self._flatten_params()

        # register the views as plain attributes
        self._unflatten_params_as_views()

    def _flatten_params(self):
        param_infos = []
        shared_param_memo = {}
        shared_param_infos = []
        params = []
        param_numels = []
        param_shapes = []
        for m in self.modules():
            for n, p in m.named_parameters(recurse=False):
                if p is not None and (m, n) in self._param_list:
                    if p in shared_param_memo:
                        shared_m, shared_n = shared_param_memo[p]
                        shared_param_infos.append((m, n, shared_m, shared_n))
                    else:
                        shared_param_memo[p] = (m, n)
                        param_infos.append((m, n))
                        params.append(p.detach())
                        param_numels.append(p.numel())
                        param_shapes.append(p.size())
        del shared_param_memo

        assert (
            len(set(p.dtype for p in params)) <= 1
        ), "expects all parameters in module to have same dtype"

        # store the info for unflatten
        self._param_infos = tuple(param_infos)
        self._shared_param_infos = tuple(shared_param_infos)
        self._param_numels = tuple(param_numels)
        self._param_shapes = tuple(param_shapes)

        # flatten
        flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
        self.register_parameter("flat_param", flat_param)
        self.param_numel = flat_param.numel()
        del params

        # deregister the names as parameters
        for m, n in self._param_infos:
            delattr(m, n)
        for m, n, _, _ in self._shared_param_infos:
            delattr(m, n)

    def _get_param_views(self):
        return (
            t.view(s)
            for (t, s) in zip(
                self.flat_param.split(self._param_numels), self._param_shapes
            )
        )

    def _unflatten_params(self):
        ps = self._get_param_views()
        for (m, n), p in zip(self._param_infos, ps):
            if hasattr(m, n):
                delattr(m, n)
            m.register_parameter(n, nn.Parameter(p))
        for (m, n, shared_m, shared_n) in self._shared_param_infos:
            if hasattr(m, n):
                delattr(m, n)
            m.register_parameter(n, getattr(shared_m, shared_n))
        del self.flat_param

    def _unflatten_params_as_views(self):
        ps = self._get_param_views()
        for (m, n), p in zip(self._param_infos, ps):
            setattr(m, n, p)  # This will set as plain attr
        for (m, n, shared_m, shared_n) in self._shared_param_infos:
            setattr(m, n, getattr(shared_m, shared_n))

    @contextmanager
    def unflatten_params(self):
        self._unflatten_params()
        yield
        self._flatten_params()
        self._unflatten_params_as_views()

    def __getattr__(self, name):
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)  # fallback to wrapped module

    def state_dict(self, *args, unflatten_params=True, **kwargs):
        if unflatten_params:
            with self.unflatten_params():
                return self.module.state_dict()
        else:
            return super().state_dict()

    def load_state_dict(self, state_dict, *args, **kwargs):
        if "flat_param" in state_dict:
            super().load_state_dict(state_dict, strict=True)
        else:
            with self.unflatten_params():
                return self.module.load_state_dict(state_dict, *args, **kwargs)

    def forward(self, *inputs, **kwinputs):
        self._unflatten_params_as_views()
        return self.module(*inputs, **kwinputs)
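
To make the wrapper's behavior concrete, here is a small usage sketch. It is illustrative only (the two-layer MLP is a made-up example, not part of this commit) and exercises the single flat_param, the transparent forward/backward, and the state_dict paths implemented above.

import torch
import torch.nn as nn

# Example module, not part of this diff.
mlp = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
ref_numel = sum(p.numel() for p in mlp.parameters())

flat = FlattenParamsWrapper(mlp)

# All of the MLP's weights now live in one flat parameter.
assert [n for n, _ in flat.named_parameters()] == ["flat_param"]
assert flat.flat_param.numel() == ref_numel

# Forward and backward go through the flattened storage transparently.
out = flat(torch.randn(2, 8))
out.sum().backward()
assert flat.flat_param.grad is not None

# state_dict() is unflattened by default; unflatten_params=False keeps the flat layout.
assert "0.weight" in flat.state_dict()
assert "flat_param" in flat.state_dict(unflatten_params=False)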
@@ -364,3 +364,40 @@ class GPT2(nn.Module):
        h = torch.mean(h, dim=0)  # average pool over sequence
        # return classification logits and generative logits
        return self.clf_head(h), logits
def objects_are_equal(a, b, raise_exception=False) -> bool:
    """
    Test that two objects are equal. Tensors are compared to ensure matching
    size, dtype, device and values.
    """
    if type(a) is not type(b):
        return False
    if isinstance(a, dict):
        if set(a.keys()) != set(b.keys()):
            return False
        for k in a.keys():
            if not objects_are_equal(a[k], b[k], raise_exception):
                return False
        return True
    elif isinstance(a, (list, tuple, set)):
        if len(a) != len(b):
            return False
        return all(objects_are_equal(x, y, raise_exception) for x, y in zip(a, b))
    elif torch.is_tensor(a):
        try:
            torch.testing.assert_allclose(a, b)
            # assert_allclose doesn't strictly test shape, dtype and device
            shape_dtype_device_match = (
                a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
            )
            assert shape_dtype_device_match
            return True
        except AssertionError as e:
            if raise_exception:
                raise e
            else:
                return False
    else:
        return a == b
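
The new tests below rely on this helper to compare model outputs and state dicts. A brief illustrative sketch of its behavior, using arbitrary example inputs:

import torch

# objects_are_equal recurses into dicts/lists/tuples and compares tensors by
# value as well as shape, dtype and device.
a = {"w": torch.ones(2, 2), "meta": [1, 2, 3]}
b = {"w": torch.ones(2, 2), "meta": [1, 2, 3]}
c = {"w": torch.zeros(2, 2), "meta": [1, 2, 3]}

assert objects_are_equal(a, b)
assert not objects_are_equal(a, c)  # tensor values differ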
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
Test FlattenParamsWrapper
"""
import unittest
import torch
from fairscale.nn import FlattenParamsWrapper
from fairscale.utils.testing import objects_are_equal
class TestFlattenParams(unittest.TestCase):
    def _get_transformer(self, seed=0):
        torch.manual_seed(seed)  # keep everything deterministic
        module = torch.nn.Transformer(
            d_model=32,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=128,
            dropout=0.1,
        )
        module.register_buffer("dummy_buffer", torch.tensor(1.0))
        return module

    def _get_shared_params_transformer(self, seed=0):
        module = self._get_transformer(seed=seed)
        # share the FFNs
        for enc_layer, dec_layer in zip(module.encoder.layers, module.decoder.layers):
            dec_layer.linear1.weight = enc_layer.linear1.weight
            dec_layer.linear2.weight = enc_layer.linear2.weight
        return module

    def _get_output(self, module):
        torch.manual_seed(1)  # keep everything deterministic
        device = next(module.parameters()).device
        dtype = next(module.parameters()).dtype
        src = torch.rand(20, 8, 32).to(device=device, dtype=dtype)  # T x B x C
        tgt = torch.rand(10, 8, 32).to(device=device, dtype=dtype)  # T x B x C
        return module(src, tgt)

    def _get_pnorm_after_step(self, module):
        optim = torch.optim.SGD(module.parameters(), lr=0.01)
        loss = self._get_output(module).sum()
        loss.backward()
        optim.step()
        return torch.norm(torch.stack([p.detach().norm() for p in module.parameters()]))

    def _test_num_params(self, module):
        ref_num_params = sum(p.numel() for p in module.parameters())
        flat_module = FlattenParamsWrapper(module)
        flat_num_params = sum(p.numel() for p in flat_module.parameters())
        assert ref_num_params == flat_num_params
        assert flat_num_params == flat_module.flat_param.numel()

    def _test_output(self, module):
        ref_output = self._get_output(module)
        flat_module = FlattenParamsWrapper(module)
        flat_output = self._get_output(flat_module)
        assert objects_are_equal(ref_output, flat_output)

    def test_partial_flattening(self):
        module = self._get_transformer()
        num_params = sum(p.numel() for p in module.parameters())

        params_to_flatten = (
            list(module.encoder.layers[1].parameters())
            + list(module.decoder.layers[0].parameters())
        )
        num_params_to_flatten = sum(p.numel() for p in params_to_flatten)

        module = FlattenParamsWrapper(module, param_list=params_to_flatten)
        assert module.flat_param.numel() == num_params_to_flatten
        assert sum(p.numel() for p in module.parameters()) == num_params

        # flattened parameters are removed
        assert len(list(module.encoder.layers[1].parameters())) == 0
        assert len(list(module.decoder.layers[0].parameters())) == 0

        # non-flattened parameters remain
        assert len(list(module.encoder.layers[0].parameters())) > 0
        assert len(list(module.decoder.layers[1].parameters())) > 0

        # test that changing the module dtype works properly
        orig_dtype = params_to_flatten[0].dtype
        new_dtype = torch.float32 if orig_dtype == torch.float16 else torch.float16
        assert module.flat_param.dtype == orig_dtype
        assert all(
            p.dtype == orig_dtype for p in module.encoder.layers[0].parameters()
        )
        module = module.to(dtype=new_dtype)
        assert module.flat_param.dtype == new_dtype
        assert all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters())

    def test_num_params(self):
        module = self._get_transformer()
        self._test_num_params(module)

    def test_shared_params_num_params(self):
        module = self._get_shared_params_transformer()
        self._test_num_params(module)

    def test_output(self):
        module = self._get_transformer()
        self._test_output(module)

    def test_shared_params_output(self):
        module = self._get_shared_params_transformer()
        self._test_output(module)

    def test_shared_params_pnorm_after_step(self):
        # incorrect parameter sharing is likely to cause problems after an
        # optimization step
        module = self._get_shared_params_transformer()
        ref_pnorm_after_step = self._get_pnorm_after_step(module)

        module = self._get_shared_params_transformer()  # recreate
        flat_module = FlattenParamsWrapper(module)
        flat_pnorm_after_step = self._get_pnorm_after_step(flat_module)

        torch.testing.assert_allclose(ref_pnorm_after_step, flat_pnorm_after_step)

    def test_state_dict_equality(self):
        module = self._get_shared_params_transformer()
        ref_state_dict = module.state_dict()

        flat_module = FlattenParamsWrapper(module)
        flat_state_dict = flat_module.state_dict()

        assert objects_are_equal(ref_state_dict, flat_state_dict)

    def test_load_state_dict(self):
        module = self._get_shared_params_transformer()
        ref_state_dict = module.state_dict()
        ref_output = self._get_output(module)

        module = self._get_shared_params_transformer(seed=1234)
        flat_module = FlattenParamsWrapper(module)
        flat_module.load_state_dict(ref_state_dict)
        flat_output = self._get_output(flat_module)

        assert objects_are_equal(ref_output, flat_output)

    def test_flat_state_dict(self):
        flat_module = self._get_shared_params_transformer()
        flat_module = FlattenParamsWrapper(flat_module)
        ref_output = self._get_output(flat_module)

        flat_state_dict = flat_module.state_dict(unflatten_params=False)

        new_module = self._get_shared_params_transformer(seed=1234)
        new_module = FlattenParamsWrapper(new_module)
        new_module.load_state_dict(flat_state_dict)
        new_output = self._get_output(new_module)

        assert objects_are_equal(ref_output, new_output)


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestFlattenParamsCUDA(TestFlattenParams):
    def _get_transformer(self, seed=0):
        module = super()._get_transformer(seed=seed)
        return module.cuda()


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestFlattenParamsCUDAHalf(TestFlattenParams):
    def _get_transformer(self, seed=0):
        module = super()._get_transformer(seed=seed)
        return module.cuda().half()


if __name__ == "__main__":
    unittest.main()
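
The suite can also be driven by the standard unittest runner; a minimal sketch, assuming the file is importable as a module (the module name below is hypothetical, since the on-disk path is not shown in this view):

import unittest

# "test_flatten_params_wrapper" is an assumed module name, not a path taken from this diff.
suite = unittest.defaultTestLoader.loadTestsFromName("test_flatten_params_wrapper")
unittest.TextTestRunner(verbosity=2).run(suite)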