Unverified Commit 03f3e833 authored by Min Xu, committed by GitHub

[fix] a bug in loading a flat state dict (#1025)



- Thanks to Alexei for spotting this.
- Added a triggering test case.
- Enhanced tests with more comments, etc.
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 5b5db28d
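In short, the flatness check in load_state_dict now looks for the flat_param_ key prefix rather than the exact key flat_param_0, which is absent when custom flat_param_names are used. A minimal sketch of the new check (the custom-name keys below are illustrative assumptions):

# Minimal sketch of the fixed check. The key names shown for the custom-name
# case are illustrative assumptions; only the "flat_param_" prefix matters.
default_keys = {"flat_param_0"}                           # default flat param naming
named_keys = {"flat_param_layer1", "flat_param_layer2"}   # assumed form with custom names

def looks_flat(state_dict_keys):
    # New check from this patch: look for the flat_param_ prefix instead of
    # requiring the exact key flat_param_0.
    return any(k.startswith("flat_param_") for k in state_dict_keys)

assert looks_flat(default_keys)  # the old check ("flat_param_0" in keys) also passed here
assert looks_flat(named_keys)    # the old check failed here and wrongly took the unflatten path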
@@ -494,8 +494,10 @@ class FlattenParamsWrapper(nn.Module):
Load a state dict. If necessary, ``unflatten_params`` will be called to
match the input state_dict.
"""
# unflatten the module automatically if the state_dict is non-flat
if self.is_flattened and "flat_param_0" not in state_dict:
# Unflatten the module automatically if the state_dict is non-flat.
# Note: we check the flat_param_ prefix since custom names can be given and flat_param_0 is
# not always in the state dict's key list.
if self.is_flattened and not any(k.startswith("flat_param_") for k in state_dict.keys()):
# This object is flattened but the state_dict is not. So we unflatten and load.
with self.unflatten_params():
return super().load_state_dict(state_dict, strict)
@@ -21,6 +21,8 @@ class TestFlattenParams(unittest.TestCase):
return [
self._get_basic_linear_module,
self._get_shared_params_transformer,
self._get_2_flatten_group_linear_module,
self._get_2_flatten_group_linear_module_with_names,
]
def _get_empty_module(self, seed=0):
@@ -37,6 +39,8 @@ class TestFlattenParams(unittest.TestCase):
return torch.rand(1).to(device=device, dtype=dtype)
module.get_input = get_input
module.param_list = None # No param_list to FPW.
module.flat_param_names = None # No flat_param_names to FPW.
return module
def _get_transformer(self, seed=0):
@@ -57,6 +61,8 @@ class TestFlattenParams(unittest.TestCase):
return (src, tgt)
module.get_input = get_input
module.param_list = None # No param_list to FPW.
module.flat_param_names = None # No flat_param_names to FPW.
return module
def _get_shared_params_transformer(self, seed=0):
@@ -79,9 +85,43 @@ class TestFlattenParams(unittest.TestCase):
return (torch.rand(8, 4).to(device=device, dtype=dtype),)
module.get_input = get_input
module.param_list = None # No param_list to FPW.
module.flat_param_names = None # No flat_param_names to FPW.
return module
def _get_output(self, module):
def _get_2_flatten_group_linear_module(self, seed=0):
module = torch.nn.Sequential(
torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
torch.nn.Linear(16, 4),
)
def get_input(device, dtype):
torch.manual_seed(1) # keep everything deterministic
return (torch.rand(8, 4).to(device=device, dtype=dtype),)
module.get_input = get_input
assert len(module) == 2, "next line assumes a len==2 sequential module"
module.param_list = [list(module[0].parameters()), list(module[1].parameters())]
module.flat_param_names = None # No flat_param_names to FPW.
return module
def _get_2_flatten_group_linear_module_with_names(self, seed=0):
module = torch.nn.Sequential(
torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
torch.nn.Linear(16, 4),
)
def get_input(device, dtype):
torch.manual_seed(1) # keep everything deterministic
return (torch.rand(8, 4).to(device=device, dtype=dtype),)
module.get_input = get_input
assert len(module) == 2, "next line assumes a len==2 sequential module"
module.param_list = [list(module[0].parameters()), list(module[1].parameters())]
module.flat_param_names = ["layer1", "layer2"]
return module
def _compute_output(self, module):
device = next(module.parameters()).device
dtype = next(module.parameters()).dtype
input = module.get_input(device, dtype)
@@ -89,12 +129,13 @@ class TestFlattenParams(unittest.TestCase):
def _get_pnorm_after_step(self, module):
optim = torch.optim.SGD(module.parameters(), lr=0.01)
loss = self._get_output(module).sum()
loss = self._compute_output(module).sum()
loss.backward()
optim.step()
return torch.norm(torch.stack([p.detach().norm() for p in module.parameters()]))
def _test_num_params(self, module):
"""Make sure numel of params are the same after flatten."""
ref_num_params = sum(p.numel() for p in module.parameters())
flat_module = FlattenParamsWrapper(module)
@@ -104,13 +145,14 @@ class TestFlattenParams(unittest.TestCase):
assert flat_num_params == flat_module.flat_param.numel()
def _test_output(self, module):
ref_output = self._get_output(module)
ref_output = self._compute_output(module)
flat_module = FlattenParamsWrapper(module)
flat_output = self._get_output(flat_module)
flat_output = self._compute_output(flat_module)
assert objects_are_equal(ref_output, flat_output)
def test_partial_flattening(self):
"""Testing some parameters are flatten, with others left non-flatten."""
module = self._get_transformer()
num_params = sum(p.numel() for p in module.parameters())
@@ -139,6 +181,7 @@ class TestFlattenParams(unittest.TestCase):
assert all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters())
def test_two_flattening_group(self):
"""Testing 2 flatten groups."""
module = self._get_transformer()
num_params = sum(p.numel() for p in module.parameters())
@@ -153,8 +196,9 @@ class TestFlattenParams(unittest.TestCase):
assert sum(p.numel() for p in module.parameters()) == num_params
def test_flatten_nothing(self):
"""Testing nothing is flatten case."""
module = self._get_transformer()
ref_out = self._get_output(module)
ref_out = self._compute_output(module)
ref_state_dict = module.state_dict()
for k, v in ref_state_dict.items():
ref_state_dict[k] = v.clone()
@@ -163,10 +207,11 @@ class TestFlattenParams(unittest.TestCase):
assert ref_state_dict.keys() == fpw_state_dict.keys()
for k, v in ref_state_dict.items():
torch.testing.assert_allclose(v, fpw_state_dict[k])
fpw_out = self._get_output(module)
fpw_out = self._compute_output(module)
torch.testing.assert_allclose(ref_out, fpw_out)
def test_empty_module(self):
"""Test module without any param."""
module = self._get_empty_module()
in_data = torch.rand(1)
ref_out = module(in_data)
@@ -223,53 +268,72 @@ class TestFlattenParams(unittest.TestCase):
for module_init_fn in self._get_module_init_fns():
module = module_init_fn()
ref_state_dict = module.state_dict()
ref_output = self._get_output(module)
ref_output = self._compute_output(module)
module = module_init_fn(seed=1234)
flat_module = FlattenParamsWrapper(module)
flat_module = FlattenParamsWrapper(
module, param_list=module.param_list, flat_param_names=module.flat_param_names
)
# This should work without the unflatten_params context manager
flat_module.load_state_dict(ref_state_dict)
flat_output = self._get_output(flat_module)
flat_output = self._compute_output(flat_module)
assert objects_are_equal(ref_output, flat_output)
# And it should work with the context manager too
with flat_module.unflatten_params():
flat_module.load_state_dict(ref_state_dict)
flat_output = self._get_output(flat_module)
flat_output = self._compute_output(flat_module)
assert objects_are_equal(ref_output, flat_output)
def test_flat_state_dict(self):
"""Test that flat state dict can be reloaded and produces the same results."""
for module_init_fn in self._get_module_init_fns():
flat_module = FlattenParamsWrapper(module_init_fn())
ref_output = self._get_output(flat_module)
orig_module = module_init_fn()
flat_module = FlattenParamsWrapper(
orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
)
ref_output = self._compute_output(flat_module)
flat_state_dict = flat_module.flat_state_dict()
new_module = FlattenParamsWrapper(module_init_fn(seed=1234))
orig_module = module_init_fn(seed=1234)
new_module = FlattenParamsWrapper(
orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
)
new_module.load_state_dict(flat_state_dict)
new_output = self._get_output(new_module)
new_output = self._compute_output(new_module)
assert objects_are_equal(ref_output, new_output)
def test_unflatten_params(self):
"""Testing using external flatten params tensors as module's params' backing data."""
for module_init_fn in self._get_module_init_fns():
module = FlattenParamsWrapper(module_init_fn())
orig_module = module_init_fn()
module = FlattenParamsWrapper(
orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
)
# keep a set of buffer keys to be used for verification below.
buffers = {k.replace("_fpw_module.", "") for k, _ in module.named_buffers()}
def clone_state_dict():
"""Return a copy of the module's current state via state_dict() API."""
return OrderedDict((k, v.clone()) for k, v in module.state_dict().items())
ref_flat_param = module.flat_param.clone()
ref_flat_params = [fp.clone() for fp in module.flat_params]
# Get the current state as a reference.
with module.unflatten_params():
ref_state_dict = clone_state_dict()
assert not torch.all(ref_flat_param == 0)
for ref_fp in ref_flat_params:
assert not torch.all(ref_fp == 0.0) # Should not all be 0s.
# confirm that unflatten_params reflects values from new_flat_param
new_flat_param = torch.full_like(module.flat_param, fill_value=42.0)
with module.unflatten_params(flat_params=[new_flat_param]):
# get new_state_dict with supplied new_flat_params.
new_flat_params = [torch.full_like(fp, fill_value=42.0) for fp in module.flat_params]
with module.unflatten_params(flat_params=new_flat_params):
new_state_dict = clone_state_dict()
# confirm that unflatten_params reflects values from new_flat_params
assert new_state_dict.keys() == ref_state_dict.keys()
for k, v in new_state_dict.items():
if k in buffers: # buffers are not changed
@@ -278,14 +342,22 @@ class TestFlattenParams(unittest.TestCase):
torch.testing.assert_allclose(v, torch.ones_like(v) * 42.0)
# after context manager exits, we go back to previous (reference) state
torch.testing.assert_allclose(module.flat_param, ref_flat_param)
assert len(module.flat_params) == len(ref_flat_params)
for i in range(len(module.flat_params)):
torch.testing.assert_allclose(module.flat_params[i], ref_flat_params[i])
# get another copy of state from the module (without external backing data)
with module.unflatten_params():
ref_state_dict2 = clone_state_dict()
# Verify it is still the same.
assert objects_are_equal(ref_state_dict, ref_state_dict2)
# if we load the new_state_dict, then the flat params should match new_flat_params
module.load_state_dict(new_state_dict)
torch.testing.assert_allclose(module.flat_param, new_flat_param)
assert len(module.flat_params) == len(new_flat_params)
for i in range(len(module.flat_params)):
torch.testing.assert_allclose(module.flat_params[i], new_flat_params[i])
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
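For reference, a minimal usage sketch of the two-flatten-group, custom-name case that the new tests exercise; the import path and the exact flat param key naming are assumptions, and only APIs shown in this diff (param_list, flat_param_names, flat_state_dict, load_state_dict) are used:

import torch
from fairscale.nn.misc import FlattenParamsWrapper  # assumed import path

# Two flatten groups with custom names, mirroring the new test helpers.
seq = torch.nn.Sequential(
    torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
    torch.nn.Linear(16, 4),
)
fpw = FlattenParamsWrapper(
    seq,
    param_list=[list(seq[0].parameters()), list(seq[1].parameters())],
    flat_param_names=["layer1", "layer2"],
)

flat_sd = fpw.flat_state_dict()
# With custom names the flat keys still carry the flat_param_ prefix, but none
# of them is the literal "flat_param_0" (assumed naming), which is exactly the
# case the old check missed.
assert any(k.startswith("flat_param_") for k in flat_sd)

# With the fix, loading the flat state dict takes the flat path instead of
# incorrectly trying to unflatten first.
fpw.load_state_dict(flat_sd)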