Unverified commit d2924670 authored by Myle Ott, committed by GitHub

[fix] Make state_dict all-gather FP32 params (#451)

parent f3359550
......@@ -180,7 +180,10 @@ class FullyShardedDataParallel(nn.Module):
params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))
self._has_params = len(params) > 0
if self.flatten_parameters and self._has_params:
if not self._has_params:
self.flatten_parameters = False
if self.flatten_parameters:
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
del module # free original module in case it helps garbage collection
self.params = [self._fsdp_wrapped_module.flat_param]
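For reference, FlattenParamsWrapper can also be used on its own; the following is a minimal sketch (assuming the usual fairscale import path) of what the wrapping above does to a module's parameters:

import torch.nn as nn
from fairscale.nn.misc import FlattenParamsWrapper

module = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
wrapped = FlattenParamsWrapper(module)

# All original parameters are replaced by views into a single flat_param,
# which is the tensor FSDP shards when flatten_parameters=True.
assert len(list(wrapped.parameters())) == 1
assert wrapped.flat_param.numel() == (4 * 4 + 4) + (4 * 2 + 2)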
......@@ -335,22 +338,27 @@ class FullyShardedDataParallel(nn.Module):
continue
p._is_sharded = True
# Shard using torch.chunk to match all-gather/reduce-scatter.
chunks = list(torch.flatten(p.data).chunk(self.world_size))
while len(chunks) < self.world_size:
chunks.append(chunks[0].new_empty(0))
# Determine number of padding elements.
num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
assert num_to_pad >= 0, num_to_pad
# Replace p.data with the relevant shard.
orig_data = p.data
p.data = chunks[self.rank].clone() # clone since we free storage below
if num_to_pad > 0:
p.data = F.pad(p.data, [0, num_to_pad])
p.data = self._get_shard(p.data)
free_storage_(orig_data)
def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
"""Return the local shard of a given full tensor."""
# Shard using torch.chunk to match all-gather/reduce-scatter.
chunks = list(torch.flatten(tensor).chunk(self.world_size))
while len(chunks) < self.world_size:
chunks.append(chunks[0].new_empty(0))
# Determine number of padding elements.
num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
assert num_to_pad >= 0, num_to_pad
shard = chunks[self.rank].clone()
if num_to_pad > 0:
shard = F.pad(shard, [0, num_to_pad])
return shard
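The refactored _get_shard keeps the existing chunk-and-pad scheme; a standalone sketch of that scheme (with rank and world_size passed explicitly instead of read from the FSDP instance) looks like this:

import torch
import torch.nn.functional as F

def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # torch.chunk may return fewer than world_size chunks, so pad the list
    # with empty tensors until every rank owns exactly one chunk.
    chunks = list(torch.flatten(tensor).chunk(world_size))
    while len(chunks) < world_size:
        chunks.append(chunks[0].new_empty(0))
    # Right-pad the local chunk to the size of chunk 0 so that all shards
    # (and therefore all-gather/reduce-scatter buckets) have equal size.
    num_to_pad = chunks[0].numel() - chunks[rank].numel()
    shard = chunks[rank].clone()
    if num_to_pad > 0:
        shard = F.pad(shard, [0, num_to_pad])
    return shard

# For example, a 10-element tensor with world_size=4 chunks as 3/3/3/1,
# so rank 3's shard is padded with two zeros up to length 3.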
def extra_repr(self) -> str:
return (
f"rank={self.rank}, world_size={self.world_size}, "
......@@ -408,32 +416,34 @@ class FullyShardedDataParallel(nn.Module):
Returns the whole (unsharded) state of the module. Parameters are not
sharded, so the resulting state_dict can be loaded directly by the
wrapped Module without any sharding-specific logic. Returned tensors
will always be typed float32.
will be full precision (e.g., FP32).
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
torch.cuda.synchronize()
self._lazy_init()
if self.mixed_precision:
# Cast buffers to FP32 so their dtype matches the full-precision params returned below.
self._all_buffers_to(dtype=torch.float32)
if self._return_full_state_dict:
if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
with self.summon_full_params():
with self.summon_full_params(volatile=True):
state_dict = super().state_dict(*args, **kwargs)
else:
torch.cuda.synchronize()
self._lazy_init()
state_dict = super().state_dict(*args, **kwargs)
else:
torch.cuda.synchronize()
self._lazy_init()
if self.flatten_parameters:
assert isinstance(self.module, FlattenParamsWrapper)
state_dict = self.module.flat_state_dict(*args, **kwargs)
else:
state_dict = super().state_dict(*args, **kwargs)
if self.cpu_offload:
for k in state_dict.keys():
state_dict[k] = state_dict[k].cpu()
if self.mixed_precision:
# In case we are in mixed precision, restore buffers back to the compute dtype (e.g., FP16).
self._all_buffers_to(dtype=self.compute_dtype)
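A hedged usage sketch of the new state_dict behavior (it assumes torch.distributed is already initialized, runs on every rank, and uses the standard fairscale import path; the Linear module is just a placeholder):

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(nn.Linear(8, 8).cuda(), mixed_precision=True)

# Must be called on all ranks, since it uses collective communication.
state_dict = fsdp_model.state_dict()

# Even with mixed_precision=True, the gathered tensors are full precision
# and keyed like the wrapped module, so a plain copy can load them directly.
assert all(t.dtype == torch.float32 for t in state_dict.values())
nn.Linear(8, 8).load_state_dict({k: v.cpu() for k, v in state_dict.items()})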
......@@ -516,29 +526,42 @@ class FullyShardedDataParallel(nn.Module):
m._require_backward_grad_sync = old_flag
@contextlib.contextmanager
def summon_full_params(self, recurse: bool = True) -> Generator:
def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
"""
A context manager to expose full params for the current FSDP instance.
Can be useful *after* forward/backward for a model to get the params for
additional processing or checking.
By default this will recursively summon all params for nested FSDP
instances; this can be disabled by setting ``recurse=False``.
additional processing or checking. Parameters will be gathered in full
precision (e.g., FP32).
.. note:: This can be used on inner FSDPs.
.. note:: This can *not* be used within a forward or backward pass. Nor
can forward and backward be started from within this context.
.. note:: The full parameters will be freed after the context manager
exits; it is up to the caller to clone them if needed.
.. note:: The full parameters can be modified, but only the portion
corresponding to the local param shard will persist after the
context manager exits (unless ``volatile=True``, in which case there
are no guarantees about persistence).
Args:
recurse (bool, Optional): recursively summon all params for nested
FSDP instances (default: True)
volatile (bool, Optional): if ``True``, modifications to params are
not guaranteed to persist after the context manager exits;
enabling this can be slightly more efficient (default: False)
"""
if recurse:
with contextlib.ExitStack() as stack:
# summon all params for any nested FlattenParamsWrapper instances
# Summon all params for any nested FSDP instances.
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module.summon_full_params(recurse=False))
# yield to the caller, with full params in all nested instances
stack.enter_context(module.summon_full_params(recurse=False, volatile=volatile))
# Yield to the caller, with full params in all nested instances.
yield
# exiting from the ExitStack will re-shard params
# Exiting from the ExitStack will re-shard params.
return
else:
torch.cuda.synchronize()
......@@ -547,13 +570,30 @@ class FullyShardedDataParallel(nn.Module):
# Set the state so that we assert when trying to go into
# forward/backward.
self.training_state = TrainingState.SUMMON_FULL_PARAMS
self._rebuild_full_params()
try:
yield
finally:
self._free_full_params()
self._use_fp32_param_shard()
self.training_state = TrainingState.IDLE
full_tensors = self._rebuild_full_params(full_precision=True)
with contextlib.ExitStack() as stack:
if self.flatten_parameters and self.module.is_flattened:
# Update flattened views to point to fully-sized tensors. We
# use self.params[0] instead of full_tensors since the
# latter may contain padding.
assert len(self.params) == 1
assert isinstance(self.module, FlattenParamsWrapper)
stack.enter_context(self.module.unflatten_params(recurse=False, flat_param=self.params[0]))
try:
yield
finally:
stack.close()
assert len(full_tensors) == len(self.params)
for p, (full_tensor, safe_to_free) in zip(self.params, full_tensors):
if not volatile:
# Copy any changes made to the full params back into
# the corresponding local shards.
local_shard = self._get_shard(full_tensor)
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free:
free_storage_(full_tensor)
self._use_fp32_param_shard()
self.training_state = TrainingState.IDLE
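A short usage sketch of summon_full_params (assuming ``fsdp_model`` is an FSDP instance like the one in the earlier sketch and the process group is initialized):

# Writable access: edits made here persist, because each rank copies its
# local shard back out of the full tensor when the context exits.
with fsdp_model.summon_full_params():
    for p in fsdp_model.parameters():
        p.data.mul_(0.5)

# Read-mostly access: volatile=True skips the copy-back, which is slightly
# cheaper, but any modification may be lost on exit.
with fsdp_model.summon_full_params(volatile=True):
    norms = [p.data.norm().item() for p in fsdp_model.parameters()]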
def _reset_lazy_init(self) -> None:
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
......@@ -953,35 +993,61 @@ class FullyShardedDataParallel(nn.Module):
m.training_state = TrainingState.IDLE
@torch.no_grad()
def _rebuild_full_params(self) -> None:
"""Gather all shards of params."""
def _rebuild_full_params(self, full_precision: bool = False) -> List[Tuple[torch.Tensor, bool]]:
"""
Gather all shards of params.
Args:
full_precision (bool, Optional): by default params will be gathered
in ``compute_dtype`` (e.g., FP16), unless *full_precision* is
``True``, in which case they will be gathered in full precision
(e.g., FP32), possibly in fresh storage.
Returns:
a list of tuples, where the first element is the full-sized param
and the second element is a bool indicating if it's safe for the
caller to free the full-sized param
"""
output_tensors: List[Tuple[torch.Tensor, bool]] = []
with torch.cuda.stream(self._streams["all_gather"]):
if self.mixed_precision:
if self.mixed_precision and not full_precision:
self._cast_fp32_param_shards_to_fp16()
for p in self.params:
if not p._is_sharded:
if self.mixed_precision:
if not p._is_sharded: # e.g., when world_size == 1
if self.mixed_precision and not full_precision:
p.data = p._fp16_shard
continue
p_size = p._full_param_padded.size()
if p._full_param_padded.storage().size() != p_size.numel():
# Allocate based on full size from all shards.
alloc_storage_(p._full_param_padded, size=p_size)
assert p_size.numel() % self.world_size == 0
if p._is_sharded:
# Fill p._full_param_padded with (p.data for each shard in self.world_size)
chunks = list(p._full_param_padded.chunk(self.world_size))
dist.all_gather(chunks, p.data, group=self.process_group)
output_tensors.append((p.data, True))
else:
p._full_param_padded.copy_(torch.flatten(p.data), non_blocking=True)
output_tensors.append((p.data, False))
continue
p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size)
# If self.cpu_offload and full_precision, we need to move the
# FP32 CPU param to CUDA for the all-gather.
p_data = p.data.to(p._full_param_padded.device)
if self.mixed_precision:
p_size = p._full_param_padded.size()
assert p_size.numel() % self.world_size == 0
if not self.mixed_precision or not full_precision:
if p._full_param_padded.storage().size() != p_size.numel():
# Allocate based on full size from all shards.
alloc_storage_(p._full_param_padded, size=p_size)
output_tensor = p._full_param_padded
else:
# Allocate fresh tensor in full precision.
output_tensor = p_data.new_zeros(p_size)
output_tensors.append((output_tensor, True))
# Fill output_tensor by all-gathering each rank's shard (the chunks are views into output_tensor).
chunks = list(output_tensor.chunk(self.world_size))
dist.all_gather(chunks, p_data, group=self.process_group)
p.data = output_tensor[: p._orig_size.numel()].view(p._orig_size)
if self.mixed_precision and not full_precision:
self._free_fp16_param_shard([p])
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
return output_tensors
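The following is a standalone sketch (not FSDP's exact code path) of the gather step above, showing how all_gather into chunk views of a padded buffer reconstructs the full flat parameter:

import torch
import torch.distributed as dist

def rebuild_full_flat_param(shard: torch.Tensor, orig_numel: int,
                            world_size: int, group=None) -> torch.Tensor:
    # Every rank holds an equally sized (padded) shard, so the output buffer
    # is simply world_size shard-sized chunks laid end to end.
    output = shard.new_zeros(shard.numel() * world_size)
    chunks = list(output.chunk(world_size))  # views into `output`
    dist.all_gather(chunks, shard, group=group)
    # Drop the padding added by _get_shard to recover the original numel.
    return output[:orig_numel]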
@torch.no_grad()
def _use_full_params(self) -> None:
......@@ -1013,7 +1079,7 @@ class FullyShardedDataParallel(nn.Module):
current_stream = torch.cuda.current_stream()
with torch.cuda.stream(self._streams["all_gather"]):
for p in params:
if not p._is_sharded:
if not p._is_sharded: # e.g., world_size == 1
if self.mixed_precision:
self._free_fp16_param_shard([p])
continue
......
......@@ -53,9 +53,6 @@ class FlattenParamsWrapper(nn.Module):
self._flatten_params()
# register the views as plain attributes
self._unflatten_params_as_views()
# Register hook to be called after state_dict() to remove the
# "_fpw_module." prefix and before load_state_dict() to add it back.
self._register_state_dict_hook(_post_state_dict_hook)
......@@ -70,10 +67,7 @@ class FlattenParamsWrapper(nn.Module):
def module(self) -> nn.Module:
return self._fpw_module
def _flatten_params(self) -> None:
assert not self.is_flattened
self.is_flattened = True
def _init_flatten_params(self) -> List[Tensor]:
param_infos = []
shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
shared_param_infos = []
......@@ -102,11 +96,22 @@ class FlattenParamsWrapper(nn.Module):
self._param_numels = tuple(param_numels)
self._param_shapes = tuple(param_shapes)
return params
def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None:
assert not self.is_flattened
self.is_flattened = True
if not hasattr(self, "_param_infos"):
assert flat_param is None
params = self._init_flatten_params()
flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
self.param_numel = flat_param.numel()
del params
# flatten
flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
assert flat_param is not None
self.register_parameter("flat_param", flat_param)
self.param_numel = flat_param.numel()
del params
# deregister the names as parameters
for m, n in self._param_infos:
......@@ -114,14 +119,18 @@ class FlattenParamsWrapper(nn.Module):
for m, n, _, _ in self._shared_param_infos:
delattr(m, n)
def _get_param_views(self) -> Generator:
return (t.view(s) for (t, s) in zip(self.flat_param.split(self._param_numels), self._param_shapes))
# register the views as plain attributes
self._unflatten_params_as_views()
def _unflatten_params(self) -> None:
assert self.is_flattened
def _get_param_views(self, flat_param: Tensor) -> Generator:
return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))
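A small self-contained illustration (with made-up shapes) of the split-and-view logic in _get_param_views: the per-parameter views share storage with the flat param, which is what lets _unflatten_params_as_views expose them as plain attributes.

from math import prod

import torch

param_shapes = [(4, 3), (4,), (2, 4)]
param_numels = [prod(s) for s in param_shapes]
flat_param = torch.arange(sum(param_numels), dtype=torch.float32)

views = [t.view(s) for t, s in zip(flat_param.split(param_numels), param_shapes)]
assert [tuple(v.shape) for v in views] == param_shapes
# No copies: the first view starts at the same storage offset as flat_param.
assert views[0].data_ptr() == flat_param.data_ptr()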
def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:
assert self.is_flattened or flat_param is not None
self.is_flattened = False
flat_param = flat_param if flat_param is not None else self.flat_param
ps = self._get_param_views()
ps = self._get_param_views(flat_param)
for (m, n), p in zip(self._param_infos, ps):
if hasattr(m, n):
delattr(m, n)
......@@ -130,41 +139,60 @@ class FlattenParamsWrapper(nn.Module):
if hasattr(m, n):
delattr(m, n)
m.register_parameter(n, getattr(shared_m, shared_n))
del self.flat_param
if hasattr(self, "flat_param"):
del self.flat_param
def _unflatten_params_as_views(self) -> None:
assert self.is_flattened
ps = self._get_param_views()
ps = self._get_param_views(self.flat_param)
for (m, n), p in zip(self._param_infos, ps):
setattr(m, n, p) # This will set as plain attr
for (m, n, shared_m, shared_n) in self._shared_param_infos:
setattr(m, n, getattr(shared_m, shared_n))
@contextmanager
def unflatten_params(self, recurse: bool = True) -> Generator:
def unflatten_params(self, recurse: bool = True, flat_param: Optional[Tensor] = None) -> Generator:
"""
Unflatten params (optionally recursively on all nested instances).
If the current instance is already unflattened, then it will remain
unflattened after the context manager exits.
Unflatten params. If the current instance is already unflattened, then
it will remain unflattened after the context manager exits.
Args:
recurse (bool, Optional): recursively unflatten all nested instances
(default: True)
flat_param (Tensor, Optional): flat param to use for unflattening.
If provided, the current instance must be in a flattened state
at the start of the context manager. The provided Tensor must be
appropriately sized and will only be used within the context
manager. After the context manager exits, we will revert to
using ``self.flat_param`` (default: None).
"""
if recurse:
with ExitStack() as stack:
# unflatten any nested FlattenParamsWrapper instances
for module in self.modules():
for name, module in self.named_modules():
if isinstance(module, FlattenParamsWrapper):
stack.enter_context(module.unflatten_params(recurse=False))
is_self = name == ""
stack.enter_context(
module.unflatten_params(recurse=False, flat_param=flat_param if is_self else None)
)
# yield to the caller, with unflattened params in all nested instances
yield
# exiting from the ExitStack will re-flatten params
return
else:
assert (
flat_param is None or self.is_flattened
), "Unflattening with custom flat_param requires current instance to be flattened"
orig_flattened = self.is_flattened
if self.is_flattened:
self._unflatten_params()
if orig_flattened:
orig_flat_param = self.flat_param
self._unflatten_params(flat_param)
yield
if orig_flattened:
self._flatten_params()
self._unflatten_params_as_views()
self._flatten_params(orig_flat_param)
def __getattr__(self, name: str) -> Any:
"""Forward missing attributes to wrapped module."""
......
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/wrap/test_wrap.py
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the MIT license found in the
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import itertools
from math import inf
......@@ -182,8 +183,13 @@ class TestComparisonToPyTorchDDP(DistributedTest):
PyTorch DDP vs. FullyShardedDataParallel.
"""
def test_nested_all_wrapped_model(self):
config = {"mixed_precision": True}
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model(self, config):
test_fn = functools.partial(self._test_identical_outputs, NestedWrappedModule, config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_all_wrapped_model(self, config):
model_fn = functools.partial(NestedWrappedModule, wrap_everything=True)
test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
spawn_and_init(test_fn)
......@@ -280,6 +286,9 @@ class TestComparisonToPyTorchDDP(DistributedTest):
model = ref_ddp_fn(model, group)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
ref_state_dict = model.module.state_dict()
if config.get("cpu_offload", False):
for k in ref_state_dict.keys():
ref_state_dict[k] = ref_state_dict[k].cpu()
# Confirm we get the same behavior using FullyShardedDataParallel.
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
......@@ -456,9 +465,8 @@ class TestSaveLoadStateDict(DistributedTest):
def _test_state_dict_before_forward(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
sd = ddp_model.state_dict()
expected_dtype = torch.float16 if ddp_model.mixed_precision else torch.float32
wt = sd["embed_tokens.weight"]
assert wt.dtype == expected_dtype, f"got dtype {wt.dtype} expected {expected_dtype}"
assert wt.dtype == torch.float32, f"got dtype {wt.dtype} expected torch.float32"
cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision)
@classmethod
......@@ -480,15 +488,11 @@ class TestSaveLoadStateDict(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model(self, config):
if config["mixed_precision"]:
return # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
test_fn = functools.partial(self._test_nested_wrapped_model, config=config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model_local_state_dict(self, config):
if config["mixed_precision"]:
return # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
test_fn = functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config)
spawn_and_init(test_fn)
......@@ -501,6 +505,8 @@ class TestSaveLoadStateDict(DistributedTest):
ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}
# Create a nested FSDP-wrapped instance.
if config["mixed_precision"]:
config["compute_dtype"] = torch.float32
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import gc
import unittest
from parameterized import parameterized
import torch
from .test_fsdp import CONFIG_OPTIONS, DistributedTest, rename_test, spawn_and_init
def get_cuda_mem():
torch.cuda.synchronize()
gc.collect()
return torch.cuda.memory_allocated()
class TestMemory(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_memory(self, config):
spawn_and_init(functools.partial(self._test_memory, config))
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_memory_volatile(self, config):
spawn_and_init(functools.partial(self._test_memory, config, volatile=True))
@classmethod
def _test_memory(self, config, rank, group, volatile=False):
model = self.get_wrapped_model(group, cuda_first=False, config=config)
self._train_for_several_steps(model, 1, autocast=model.mixed_precision)
mems = [get_cuda_mem()]
with model.summon_full_params(volatile=volatile):
mems.append(get_cuda_mem())
assert mems[1] >= mems[0]
state_dict = model.state_dict()
mems.append(get_cuda_mem())
assert mems[2] >= mems[1]
mems.append(get_cuda_mem())
assert mems[3] <= mems[2]
del state_dict
mems.append(get_cuda_mem())
assert mems[4] == mems[0]
class TestPersistence(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_non_volatile(self, config):
spawn_and_init(functools.partial(self._test_persistence, config))
@classmethod
def _test_persistence(self, config, rank, group, volatile=False):
model = self.get_wrapped_model(group, cuda_first=False, config=config)
with model.summon_full_params(volatile=False):
model.module.embed_tokens.weight.data.fill_(42)
with model.summon_full_params():
# non-volatile changes are persisted
assert torch.all(model.module.embed_tokens.weight.data == 42.0)
if __name__ == "__main__":
unittest.main()
......@@ -29,17 +29,26 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
my_lr = 0.1
device = torch.device("cuda")
if fsdp_config.get("mixed_precision", False):
dtype = torch.float16
fsdp_config["fp32_reduce_scatter"] = True
else:
dtype = torch.float32
if test_case["assert_ref_out"]:
with torch.no_grad():
# Compute one iteration local output.
weight = model.weight.T.clone().cuda()
v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
fp32_weight = model.weight.T.clone().to(device)
weight = fp32_weight.to(dtype)
v = torch.Tensor(test_case["inputs"][0][rank]).to(device, dtype)
ref_forward_output_my_rank = torch.matmul(v, weight)
# Compute one iteration global weight update.
v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda()
grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
ref_weight_out = weight - grad.T * my_lr
model.to("cuda")
v = torch.Tensor(test_case["inputs"][0][:world_size]).to(device, dtype)
grad = v.float().sum(0).repeat(weight.shape[0], 1).div(world_size)
ref_weight_out = fp32_weight - grad.T * my_lr
assert ref_weight_out.dtype == torch.float32
model.to(device) # not dtype, since FSDP will manage mixed precision internally
assert isinstance(fsdp_config, dict), str(fsdp_config)
model = FSDP(model, **fsdp_config)
optim = SGD(model.parameters(), lr=my_lr)
......@@ -47,9 +56,9 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
assert len(inputs) == 1 or not test_case["assert_ref_out"]
assert len(inputs[0]) >= world_size
for in_data in inputs:
in_data = Tensor(in_data[rank]).cuda()
in_data = Tensor(in_data[rank]).to(device, dtype)
out = model(in_data)
out.sum().backward()
out.float().sum().backward()
optim.step()
optim.zero_grad()
if test_case["assert_ref_out"]:
......@@ -70,7 +79,7 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
@skip_if_single_gpu
@pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3)], "assert_ref_out": True}])
@pytest.mark.parametrize(
"fsdp_config", [{}, {"flatten_parameters": False}],
"fsdp_config", [{}, {"flatten_parameters": False}, {"mixed_precision": True}],
)
@pytest.mark.parametrize("world_size", list(range(2, 9)))
def test_one_iteration(world_size, test_case, fsdp_config):
......
......@@ -7,6 +7,7 @@
Test FlattenParamsWrapper
"""
from collections import OrderedDict
import unittest
import torch
......@@ -196,6 +197,40 @@ class TestFlattenParams(unittest.TestCase):
assert objects_are_equal(ref_output, new_output)
def test_unflatten_params(self):
for module_init_fn in self._get_module_init_fns():
module = FlattenParamsWrapper(module_init_fn())
buffers = {k.replace("_fpw_module.", "") for k, _ in module.named_buffers()}
def clone_state_dict():
return OrderedDict((k, v.clone()) for k, v in module.state_dict().items())
ref_flat_param = module.flat_param.clone()
with module.unflatten_params():
ref_state_dict = clone_state_dict()
assert not torch.all(ref_flat_param == 0)
# confirm that unflatten_params reflects values from new_flat_param
new_flat_param = torch.full_like(module.flat_param, fill_value=42.0)
with module.unflatten_params(flat_param=new_flat_param):
new_state_dict = clone_state_dict()
assert new_state_dict.keys() == ref_state_dict.keys()
for k, v in new_state_dict.items():
if k in buffers: # buffers are not changed
torch.testing.assert_allclose(v, ref_state_dict[k])
else: # params reflect new_flat_param value
assert torch.all(v == 42.0)
# after context manager exits, we go back to previous (reference) state
torch.testing.assert_allclose(module.flat_param, ref_flat_param)
with module.unflatten_params():
ref_state_dict2 = clone_state_dict()
assert objects_are_equal(ref_state_dict, ref_state_dict2)
# if we load the new_state_dict, then the flat param should match new_flat_param
module.load_state_dict(new_state_dict)
torch.testing.assert_allclose(module.flat_param, new_flat_param)
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestFlattenParamsCUDA(TestFlattenParams):
......