Unverified Commit 03f3e833 authored by Min Xu, committed by GitHub

[fix] a bug in loading a flattened state dict (#1025)



- Thanks to Alexei for spotting this.
- Added a triggering test case.
- Enhanced tests with more comments, etc.
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 5b5db28d
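
A minimal reproduction sketch of the bug (standalone, not part of the commit; the FairScale import path is an assumption, and the group layout and custom names simply mirror the new test case added below):

    import torch
    from fairscale.nn.misc import FlattenParamsWrapper  # import path is an assumption

    # Two flatten groups with custom flat-param names, mirroring the new test helper below.
    module = torch.nn.Sequential(
        torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
        torch.nn.Linear(16, 4),
    )
    flat = FlattenParamsWrapper(
        module,
        param_list=[list(module[0].parameters()), list(module[1].parameters())],
        flat_param_names=["layer1", "layer2"],
    )

    state = flat.flat_state_dict()
    # With custom names there is no "flat_param_0" key, so the old check
    # ('"flat_param_0" not in state_dict') treated this flat state dict as non-flat
    # and took the unflatten path. The fix checks the "flat_param_" prefix instead:
    assert any(k.startswith("flat_param_") for k in state.keys())
    flat.load_state_dict(state)  # now loads the flat state dict directly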
@@ -494,8 +494,10 @@ class FlattenParamsWrapper(nn.Module):
         Load a state dict. If necessary, ``unflatten_params`` will be called to
         match the input state_dict.
         """
-        # unflatten the module automatically if the state_dict is non-flat
-        if self.is_flattened and "flat_param_0" not in state_dict:
+        # Unflatten the module automatically if the state_dict is non-flat.
+        # Note, we check the flat_param_ prefix since custom names can be given and flat_param_0 is
+        # not always in the state dict's key list.
+        if self.is_flattened and not any(k.startswith("flat_param_") for k in state_dict.keys()):
             # This object is flatten but state_dict is not. So we unflatten and load.
             with self.unflatten_params():
                 return super().load_state_dict(state_dict, strict)
...
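
To make the new condition concrete, here is a small illustration (the key names are hypothetical; only the "flat_param_" prefix matters to the check):

    # Hypothetical key sets that load_state_dict() might receive.
    default_flat_keys = ["flat_param_0"]                            # default naming
    named_flat_keys = ["flat_param_layer1", "flat_param_layer2"]    # custom names (illustrative)
    non_flat_keys = ["0.weight", "0.bias", "1.weight", "1.bias"]    # unflattened keys (illustrative)

    def looks_flat(keys):
        # Mirrors the fixed condition: any "flat_param_"-prefixed key means "already flat".
        return any(k.startswith("flat_param_") for k in keys)

    assert looks_flat(default_flat_keys) and looks_flat(named_flat_keys)
    assert not looks_flat(non_flat_keys)
    # The old check, '"flat_param_0" not in state_dict', misclassifies named_flat_keys
    # as non-flat and takes the unflatten path, which is the bug fixed above.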
@@ -21,6 +21,8 @@ class TestFlattenParams(unittest.TestCase):
         return [
             self._get_basic_linear_module,
             self._get_shared_params_transformer,
+            self._get_2_flatten_group_linear_module,
+            self._get_2_flatten_group_linear_module_with_names,
         ]

     def _get_empty_module(self, seed=0):
@@ -37,6 +39,8 @@ class TestFlattenParams(unittest.TestCase):
             return torch.rand(1).to(device=device, dtype=dtype)

         module.get_input = get_input
+        module.param_list = None  # No param_list to FPW.
+        module.flat_param_names = None  # No flat_param_names to FPW.
         return module

     def _get_transformer(self, seed=0):
@@ -57,6 +61,8 @@ class TestFlattenParams(unittest.TestCase):
             return (src, tgt)

         module.get_input = get_input
+        module.param_list = None  # No param_list to FPW.
+        module.flat_param_names = None  # No flat_param_names to FPW.
         return module

     def _get_shared_params_transformer(self, seed=0):
@@ -79,9 +85,43 @@ class TestFlattenParams(unittest.TestCase):
             return (torch.rand(8, 4).to(device=device, dtype=dtype),)

         module.get_input = get_input
+        module.param_list = None  # No param_list to FPW.
+        module.flat_param_names = None  # No flat_param_names to FPW.
         return module

-    def _get_output(self, module):
+    def _get_2_flatten_group_linear_module(self, seed=0):
+        module = torch.nn.Sequential(
+            torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
+            torch.nn.Linear(16, 4),
+        )
+
+        def get_input(device, dtype):
+            torch.manual_seed(1)  # keep everything deterministic
+            return (torch.rand(8, 4).to(device=device, dtype=dtype),)
+
+        module.get_input = get_input
+        assert len(module) == 2, "next line assumes a len==2 sequential module"
+        module.param_list = [list(module[0].parameters()), list(module[1].parameters())]
+        module.flat_param_names = None  # No flat_param_names to FPW.
+        return module
+
+    def _get_2_flatten_group_linear_module_with_names(self, seed=0):
+        module = torch.nn.Sequential(
+            torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
+            torch.nn.Linear(16, 4),
+        )
+
+        def get_input(device, dtype):
+            torch.manual_seed(1)  # keep everything deterministic
+            return (torch.rand(8, 4).to(device=device, dtype=dtype),)
+
+        module.get_input = get_input
+        assert len(module) == 2, "next line assumes a len==2 sequential module"
+        module.param_list = [list(module[0].parameters()), list(module[1].parameters())]
+        module.flat_param_names = ["layer1", "layer2"]
+        return module
+
+    def _compute_output(self, module):
         device = next(module.parameters()).device
         dtype = next(module.parameters()).dtype
         input = module.get_input(device, dtype)
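
For orientation, a small standalone sketch of how the two-flatten-group helpers above are consumed (the import path is an assumption; the constructor arguments and the flat_params attribute follow the test code in this diff):

    import torch
    from fairscale.nn.misc import FlattenParamsWrapper  # import path assumed

    # Same two-group layout as the helpers above.
    module = torch.nn.Sequential(
        torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
        torch.nn.Linear(16, 4),
    )
    total_numel = sum(p.numel() for p in module.parameters())
    param_list = [list(module[0].parameters()), list(module[1].parameters())]

    flat = FlattenParamsWrapper(module, param_list=param_list, flat_param_names=None)

    # One flat parameter per group; together they hold every element of the module.
    assert len(flat.flat_params) == 2
    assert sum(fp.numel() for fp in flat.flat_params) == total_numel

    # Forward works as with the unwrapped module.
    out = flat(torch.rand(8, 4))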
@@ -89,12 +129,13 @@ class TestFlattenParams(unittest.TestCase):
     def _get_pnorm_after_step(self, module):
         optim = torch.optim.SGD(module.parameters(), lr=0.01)
-        loss = self._get_output(module).sum()
+        loss = self._compute_output(module).sum()
         loss.backward()
         optim.step()
         return torch.norm(torch.stack([p.detach().norm() for p in module.parameters()]))

     def _test_num_params(self, module):
+        """Make sure numel of params are the same after flatten."""
         ref_num_params = sum(p.numel() for p in module.parameters())

         flat_module = FlattenParamsWrapper(module)
@@ -104,13 +145,14 @@ class TestFlattenParams(unittest.TestCase):
         assert flat_num_params == flat_module.flat_param.numel()

     def _test_output(self, module):
-        ref_output = self._get_output(module)
+        ref_output = self._compute_output(module)

         flat_module = FlattenParamsWrapper(module)
-        flat_output = self._get_output(flat_module)
+        flat_output = self._compute_output(flat_module)

         assert objects_are_equal(ref_output, flat_output)

     def test_partial_flattening(self):
+        """Testing some parameters are flatten, with others left non-flatten."""
         module = self._get_transformer()
         num_params = sum(p.numel() for p in module.parameters())
@@ -139,6 +181,7 @@ class TestFlattenParams(unittest.TestCase):
         assert all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters())

     def test_two_flattening_group(self):
+        """Testing 2 flatten groups."""
         module = self._get_transformer()
         num_params = sum(p.numel() for p in module.parameters())
@@ -153,8 +196,9 @@ class TestFlattenParams(unittest.TestCase):
         assert sum(p.numel() for p in module.parameters()) == num_params

     def test_flatten_nothing(self):
+        """Testing nothing is flatten case."""
         module = self._get_transformer()
-        ref_out = self._get_output(module)
+        ref_out = self._compute_output(module)
         ref_state_dict = module.state_dict()
         for k, v in ref_state_dict.items():
             ref_state_dict[k] = v.clone()
@@ -163,10 +207,11 @@ class TestFlattenParams(unittest.TestCase):
         assert ref_state_dict.keys() == fpw_state_dict.keys()
         for k, v in ref_state_dict.items():
             torch.testing.assert_allclose(v, fpw_state_dict[k])
-        fpw_out = self._get_output(module)
+        fpw_out = self._compute_output(module)
         torch.testing.assert_allclose(ref_out, fpw_out)

     def test_empty_module(self):
+        """Test module without any param."""
         module = self._get_empty_module()
         in_data = torch.rand(1)
         ref_out = module(in_data)
@@ -223,69 +268,96 @@ class TestFlattenParams(unittest.TestCase):
         for module_init_fn in self._get_module_init_fns():
             module = module_init_fn()
             ref_state_dict = module.state_dict()
-            ref_output = self._get_output(module)
+            ref_output = self._compute_output(module)

             module = module_init_fn(seed=1234)
-            flat_module = FlattenParamsWrapper(module)
+            flat_module = FlattenParamsWrapper(
+                module, param_list=module.param_list, flat_param_names=module.flat_param_names
+            )

             # This should work without the unflatten_params context manager
             flat_module.load_state_dict(ref_state_dict)
-            flat_output = self._get_output(flat_module)
+            flat_output = self._compute_output(flat_module)
             assert objects_are_equal(ref_output, flat_output)

             # And it should work with the context manager too
             with flat_module.unflatten_params():
                 flat_module.load_state_dict(ref_state_dict)
-            flat_output = self._get_output(flat_module)
+            flat_output = self._compute_output(flat_module)
             assert objects_are_equal(ref_output, flat_output)

     def test_flat_state_dict(self):
         """Test that flat state dict can be reloaded and produces the same results."""
         for module_init_fn in self._get_module_init_fns():
-            flat_module = FlattenParamsWrapper(module_init_fn())
-            ref_output = self._get_output(flat_module)
+            orig_module = module_init_fn()
+            flat_module = FlattenParamsWrapper(
+                orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
+            )
+            ref_output = self._compute_output(flat_module)

             flat_state_dict = flat_module.flat_state_dict()

-            new_module = FlattenParamsWrapper(module_init_fn(seed=1234))
+            orig_module = module_init_fn(seed=1234)
+            new_module = FlattenParamsWrapper(
+                orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
+            )
             new_module.load_state_dict(flat_state_dict)
-            new_output = self._get_output(new_module)
+            new_output = self._compute_output(new_module)
             assert objects_are_equal(ref_output, new_output)

     def test_unflatten_params(self):
+        """Testing using external flatten params tensors as module's params' backing data."""
         for module_init_fn in self._get_module_init_fns():
-            module = FlattenParamsWrapper(module_init_fn())
+            orig_module = module_init_fn()
+            module = FlattenParamsWrapper(
+                orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
+            )
+            # keep a list of buffer's key to be used for verification below.
             buffers = {k.replace("_fpw_module.", "") for k, _ in module.named_buffers()}

             def clone_state_dict():
+                """Return a copy of the module's current state via state_dict() API."""
                 return OrderedDict((k, v.clone()) for k, v in module.state_dict().items())

-            ref_flat_param = module.flat_param.clone()
+            ref_flat_params = [fp.clone() for fp in module.flat_params]
+
+            # Get the current state as a reference.
             with module.unflatten_params():
                 ref_state_dict = clone_state_dict()
-            assert not torch.all(ref_flat_param == 0)
+            for ref_fp in ref_flat_params:
+                assert not torch.all(ref_fp == 0.0)  # Should not all be 0s.

-            # confirm that unflatten_params reflects values from new_flat_param
-            new_flat_param = torch.full_like(module.flat_param, fill_value=42.0)
-            with module.unflatten_params(flat_params=[new_flat_param]):
+            # get new_state_dict with supplied new_flat_params.
+            new_flat_params = [torch.full_like(fp, fill_value=42.0) for fp in module.flat_params]
+            with module.unflatten_params(flat_params=new_flat_params):
                 new_state_dict = clone_state_dict()
+
+            # confirm that unflatten_params reflects values from new_flat_param
             assert new_state_dict.keys() == ref_state_dict.keys()
             for k, v in new_state_dict.items():
                 if k in buffers:  # buffers are not changed
                     torch.testing.assert_allclose(v, ref_state_dict[k])
                 else:  # params reflect new_flat_param value
                     torch.testing.assert_allclose(v, torch.ones_like(v) * 42.0)

             # after context manager exits, we go back to previous (reference) state
-            torch.testing.assert_allclose(module.flat_param, ref_flat_param)
+            assert len(module.flat_params) == len(ref_flat_params)
+            for i in range(len(module.flat_params)):
+                torch.testing.assert_allclose(module.flat_params[i], ref_flat_params[i])

+            # get another copy of state from the module (without external backing data)
             with module.unflatten_params():
                 ref_state_dict2 = clone_state_dict()
+
+            # Verify it is still the same.
             assert objects_are_equal(ref_state_dict, ref_state_dict2)

             # if we load the new_state_dict, then the flat param should match new_flat_param
             module.load_state_dict(new_state_dict)
-            torch.testing.assert_allclose(module.flat_param, new_flat_param)
+            assert len(module.flat_params) == len(new_flat_params)
+            for i in range(len(module.flat_params)):
+                torch.testing.assert_allclose(module.flat_params[i], new_flat_params[i])

     @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
...
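
The test_unflatten_params changes above exercise externally supplied flat tensors as temporary backing data; a condensed sketch of that pattern (standalone; the Linear module and variable names are illustrative, while the wrapper calls follow the test code):

    import torch
    from fairscale.nn.misc import FlattenParamsWrapper  # import path assumed

    flat_module = FlattenParamsWrapper(torch.nn.Linear(4, 4))

    # External tensors temporarily back the module's unflattened parameters.
    fake_flat_params = [torch.full_like(fp, 42.0) for fp in flat_module.flat_params]
    with flat_module.unflatten_params(flat_params=fake_flat_params):
        # Inside the context, parameter entries of state_dict() read 42.0.
        sd = {k: v.clone() for k, v in flat_module.state_dict().items()}
    # On exit, the wrapper's own flat_params are restored unchanged.

    # Loading the captured state afterwards does write 42.0 into the flat params.
    flat_module.load_state_dict(sd)
    assert all(bool(torch.all(fp == 42.0)) for fp in flat_module.flat_params)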