Unverified commit d2924670, authored by Myle Ott, committed by GitHub

[fix] Make state_dict all-gather FP32 params (#451)

parent f3359550
@@ -180,7 +180,10 @@ class FullyShardedDataParallel(nn.Module):
         params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))
         self._has_params = len(params) > 0
-        if self.flatten_parameters and self._has_params:
+        if not self._has_params:
+            self.flatten_parameters = False
+        if self.flatten_parameters:
             self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
             del module  # free original module in case it helps garbage collection
             self.params = [self._fsdp_wrapped_module.flat_param]
@@ -335,22 +338,27 @@ class FullyShardedDataParallel(nn.Module):
                 continue
             p._is_sharded = True

-            # Shard using torch.chunk to match all-gather/reduce-scatter.
-            chunks = list(torch.flatten(p.data).chunk(self.world_size))
-            while len(chunks) < self.world_size:
-                chunks.append(chunks[0].new_empty(0))
-            # Determine number of padding elements.
-            num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
-            assert num_to_pad >= 0, num_to_pad

             # Replace p.data with the relevant shard.
             orig_data = p.data
-            p.data = chunks[self.rank].clone()  # clone since we free storage below
-            if num_to_pad > 0:
-                p.data = F.pad(p.data, [0, num_to_pad])
+            p.data = self._get_shard(p.data)
             free_storage_(orig_data)

+    def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
+        """Return the local shard of a given full tensor."""
+        # Shard using torch.chunk to match all-gather/reduce-scatter.
+        chunks = list(torch.flatten(tensor).chunk(self.world_size))
+        while len(chunks) < self.world_size:
+            chunks.append(chunks[0].new_empty(0))
+        # Determine number of padding elements.
+        num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
+        assert num_to_pad >= 0, num_to_pad
+        shard = chunks[self.rank].clone()
+        if num_to_pad > 0:
+            shard = F.pad(shard, [0, num_to_pad])
+        return shard

     def extra_repr(self) -> str:
         return (
             f"rank={self.rank}, world_size={self.world_size}, "
@@ -408,32 +416,34 @@ class FullyShardedDataParallel(nn.Module):
         Returns the whole (unsharded) state of the module. Parameters are not
         sharded, so the resulting state_dict can be loaded directly by the
         wrapped Module without any sharding-specific logic. Returned tensors
-        will always be typed float32.
+        will be full precision (e.g., FP32).

         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
-        torch.cuda.synchronize()
-        self._lazy_init()
         if self.mixed_precision:
             # Buffers dtype stays consistent with parameters.
             self._all_buffers_to(dtype=torch.float32)

         if self._return_full_state_dict:
             if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
-                with self.summon_full_params():
+                with self.summon_full_params(volatile=True):
                     state_dict = super().state_dict(*args, **kwargs)
             else:
+                torch.cuda.synchronize()
+                self._lazy_init()
                 state_dict = super().state_dict(*args, **kwargs)
         else:
+            torch.cuda.synchronize()
+            self._lazy_init()
             if self.flatten_parameters:
                 assert isinstance(self.module, FlattenParamsWrapper)
                 state_dict = self.module.flat_state_dict(*args, **kwargs)
             else:
                 state_dict = super().state_dict(*args, **kwargs)

+        if self.cpu_offload:
+            for k in state_dict.keys():
+                state_dict[k] = state_dict[k].cpu()

         if self.mixed_precision:
             # In case we are in mixed precision, restore buffers back to fp16.
             self._all_buffers_to(dtype=self.compute_dtype)
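As a rough usage sketch (assuming an initialized process group and an already-wrapped model; the helper name is illustrative and not part of fairscale), the change means `state_dict()` can be called under `mixed_precision=True` and still yield FP32 tensors, as long as every rank participates in the call:

```python
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def save_full_checkpoint(model: FSDP, path: str) -> None:
    # Collective call: must run on ALL ranks, since shards are all-gathered.
    state = model.state_dict()
    # With this commit, floating-point entries come back in full precision
    # even when the model was wrapped with mixed_precision=True.
    assert all(t.dtype == torch.float32 for t in state.values() if t.is_floating_point())
    if dist.get_rank() == 0:
        torch.save(state, path)  # full, unsharded, FP32 state dict
```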
@@ -516,29 +526,42 @@ class FullyShardedDataParallel(nn.Module):
             m._require_backward_grad_sync = old_flag

     @contextlib.contextmanager
-    def summon_full_params(self, recurse: bool = True) -> Generator:
+    def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
         """
         A context manager to expose full params for the current FSDP instance.
         Can be useful *after* forward/backward for a model to get the params for
-        additional processing or checking.
+        additional processing or checking. Parameters will be gathered in full
+        precision (e.g., FP32).

-        By default this will recursively summon all params for nested FSDP
-        instances; this can be disabled by setting ``recurse=False``.

         .. note:: This can be used on inner FSDPs.

         .. note:: This can *not* be used within a forward or backward pass. Nor
             can forward and backward be started from within this context.

+        .. note:: The full parameters will be freed after the context manager
+            exits; it is up to the caller to clone them if needed.

+        .. note:: The full parameters can be modified, but only the portion
+            corresponding to the local param shard will persist after the
+            context manager exits (unless ``volatile=True``, in which case there
+            are no guarantees about persistence).

+        Args:
+            recurse (bool, Optional): recursively summon all params for nested
+                FSDP instances (default: True)
+            volatile (bool, Optional): if ``True``, modifications to params are
+                not guaranteed to persist after the context manager exits;
+                enabling this can be slightly more efficient (default: False)
         """
         if recurse:
             with contextlib.ExitStack() as stack:
-                # summon all params for any nested FlattenParamsWrapper instances
+                # Summon all params for any nested FSDP instances.
                 for module in self.modules():
                     if isinstance(module, FullyShardedDataParallel):
-                        stack.enter_context(module.summon_full_params(recurse=False))
+                        stack.enter_context(module.summon_full_params(recurse=False, volatile=volatile))
-                # yield to the caller, with full params in all nested instances
+                # Yield to the caller, with full params in all nested instances.
                 yield
-                # exiting from the ExitStack will re-shard params
+                # Exiting from the ExitStack will re-shard params.
                 return
         else:
             torch.cuda.synchronize()
@@ -547,13 +570,30 @@ class FullyShardedDataParallel(nn.Module):
             # Set the state so that we assert when trying to go into
             # forward/backward.
             self.training_state = TrainingState.SUMMON_FULL_PARAMS
-            self._rebuild_full_params()
-            try:
-                yield
-            finally:
-                self._free_full_params()
-                self._use_fp32_param_shard()
-                self.training_state = TrainingState.IDLE
+            full_tensors = self._rebuild_full_params(full_precision=True)
+            with contextlib.ExitStack() as stack:
+                if self.flatten_parameters and self.module.is_flattened:
+                    # Update flattened views to point to fully-sized tensors. We
+                    # use self.params[0] instead of full_tensors since the
+                    # latter may contain padding.
+                    assert len(self.params) == 1
+                    assert isinstance(self.module, FlattenParamsWrapper)
+                    stack.enter_context(self.module.unflatten_params(recurse=False, flat_param=self.params[0]))
+                try:
+                    yield
+                finally:
+                    stack.close()
+                    assert len(full_tensors) == len(self.params)
+                    for p, (full_tensor, safe_to_free) in zip(self.params, full_tensors):
+                        if not volatile:
+                            # Copy any changes made to the full params back into
+                            # the corresponding local shards.
+                            local_shard = self._get_shard(full_tensor)
+                            p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
+                        if safe_to_free:
+                            free_storage_(full_tensor)
+                    self._use_fp32_param_shard()
+                    self.training_state = TrainingState.IDLE

     def _reset_lazy_init(self) -> None:
         """Reset instance so :func:`_lazy_init` will run on the next forward."""
@@ -953,35 +993,61 @@ class FullyShardedDataParallel(nn.Module):
             m.training_state = TrainingState.IDLE

     @torch.no_grad()
-    def _rebuild_full_params(self) -> None:
-        """Gather all shards of params."""
+    def _rebuild_full_params(self, full_precision: bool = False) -> List[Tuple[torch.Tensor, bool]]:
+        """
+        Gather all shards of params.
+
+        Args:
+            full_precision (bool, Optional): by default params will be gathered
+                in ``compute_dtype`` (e.g., FP16), unless *full_precision* is
+                ``True``, in which case they will be gathered in full precision
+                (e.g., FP32), possibly in fresh storage.
+
+        Returns:
+            a list of tuples, where the first element is the full-sized param
+            and the second element is a bool indicating if it's safe for the
+            caller to free the full-sized param
+        """
+        output_tensors: List[Tuple[torch.Tensor, bool]] = []
         with torch.cuda.stream(self._streams["all_gather"]):
-            if self.mixed_precision:
+            if self.mixed_precision and not full_precision:
                 self._cast_fp32_param_shards_to_fp16()
             for p in self.params:
-                if not p._is_sharded:
-                    if self.mixed_precision:
+                if not p._is_sharded:  # e.g., when world_size == 1
+                    if self.mixed_precision and not full_precision:
                         p.data = p._fp16_shard
+                        output_tensors.append((p.data, True))
+                    else:
+                        output_tensors.append((p.data, False))
                     continue
-                p_size = p._full_param_padded.size()
-                if p._full_param_padded.storage().size() != p_size.numel():
-                    # Allocate based on full size from all shards.
-                    alloc_storage_(p._full_param_padded, size=p_size)
-                    assert p_size.numel() % self.world_size == 0
-                    if p._is_sharded:
-                        # Fill p._full_param_padded with (p.data for each shard in self.world_size)
-                        chunks = list(p._full_param_padded.chunk(self.world_size))
-                        dist.all_gather(chunks, p.data, group=self.process_group)
-                    else:
-                        p._full_param_padded.copy_(torch.flatten(p.data), non_blocking=True)
-                p.data = p._full_param_padded[: p._orig_size.numel()].view(p._orig_size)
-                if self.mixed_precision:
+                # If self.cpu_offload and full_precision, we need to cast the
+                # FP32 CPU param to CUDA for the all-gather.
+                p_data = p.data.to(p._full_param_padded.device)
+                p_size = p._full_param_padded.size()
+                assert p_size.numel() % self.world_size == 0
+                if not self.mixed_precision or not full_precision:
+                    if p._full_param_padded.storage().size() != p_size.numel():
+                        # Allocate based on full size from all shards.
+                        alloc_storage_(p._full_param_padded, size=p_size)
+                    output_tensor = p._full_param_padded
+                else:
+                    # Allocate fresh tensor in full precision.
+                    output_tensor = p_data.new_zeros(p_size)
+                output_tensors.append((output_tensor, True))
+                # Fill output_tensor with (p.data for each shard in self.world_size)
+                chunks = list(output_tensor.chunk(self.world_size))
+                dist.all_gather(chunks, p_data, group=self.process_group)
+                p.data = output_tensor[: p._orig_size.numel()].view(p._orig_size)
+                if self.mixed_precision and not full_precision:
                     self._free_fp16_param_shard([p])
         torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
+        return output_tensors

     @torch.no_grad()
     def _use_full_params(self) -> None:
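The core gather pattern above, all-gathering each rank's padded shard directly into chunk views of one contiguous output tensor and then trimming the padding, can be sketched outside FSDP as follows (hypothetical helper; assumes `torch.distributed` is initialized and every rank passes an equally sized flat shard):

```python
import torch
import torch.distributed as dist


def gather_full_param(shard: torch.Tensor, orig_numel: int, world_size: int) -> torch.Tensor:
    # Allocate one contiguous buffer big enough for every rank's padded shard.
    output = shard.new_zeros(shard.numel() * world_size)
    # chunk() on a contiguous 1-D tensor yields views into `output`, so the
    # all-gather below fills the full buffer in place, shard by shard.
    chunks = list(output.chunk(world_size))
    dist.all_gather(chunks, shard)
    # Drop the trailing padding added when the parameter was sharded.
    return output[:orig_numel]


# e.g. (illustrative): full = gather_full_param(my_padded_shard,
#                                               orig_numel=original_param_numel,
#                                               world_size=dist.get_world_size())
```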
@@ -1013,7 +1079,7 @@ class FullyShardedDataParallel(nn.Module):
         current_stream = torch.cuda.current_stream()
         with torch.cuda.stream(self._streams["all_gather"]):
             for p in params:
-                if not p._is_sharded:
+                if not p._is_sharded:  # e.g., world_size == 1
                     if self.mixed_precision:
                         self._free_fp16_param_shard([p])
                     continue
......
@@ -53,9 +53,6 @@ class FlattenParamsWrapper(nn.Module):
         self._flatten_params()

-        # register the views as plain attributes
-        self._unflatten_params_as_views()

         # Register hook to be called after state_dict() to remove the
         # "_fpw_module." prefix and before load_state_dict() to add it back.
         self._register_state_dict_hook(_post_state_dict_hook)
@@ -70,10 +67,7 @@ class FlattenParamsWrapper(nn.Module):
     def module(self) -> nn.Module:
         return self._fpw_module

-    def _flatten_params(self) -> None:
-        assert not self.is_flattened
-        self.is_flattened = True
+    def _init_flatten_params(self) -> List[Tensor]:
         param_infos = []
         shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
         shared_param_infos = []
@@ -102,11 +96,22 @@ class FlattenParamsWrapper(nn.Module):
         self._param_numels = tuple(param_numels)
         self._param_shapes = tuple(param_shapes)
+        return params
+
+    def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None:
+        assert not self.is_flattened
+        self.is_flattened = True
+
+        if not hasattr(self, "_param_infos"):
+            assert flat_param is None
+            params = self._init_flatten_params()
+            flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
+            self.param_numel = flat_param.numel()
+            del params

         # flatten
-        flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
+        assert flat_param is not None
         self.register_parameter("flat_param", flat_param)
-        self.param_numel = flat_param.numel()
-        del params

         # deregister the names as parameters
         for m, n in self._param_infos:
@@ -114,14 +119,18 @@ class FlattenParamsWrapper(nn.Module):
         for m, n, _, _ in self._shared_param_infos:
             delattr(m, n)
+        # register the views as plain attributes
+        self._unflatten_params_as_views()

-    def _get_param_views(self) -> Generator:
-        return (t.view(s) for (t, s) in zip(self.flat_param.split(self._param_numels), self._param_shapes))
+    def _get_param_views(self, flat_param: Tensor) -> Generator:
+        return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))

-    def _unflatten_params(self) -> None:
-        assert self.is_flattened
+    def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:
+        assert self.is_flattened or flat_param is not None
         self.is_flattened = False
+        flat_param = flat_param if flat_param is not None else self.flat_param
-        ps = self._get_param_views()
+        ps = self._get_param_views(flat_param)
         for (m, n), p in zip(self._param_infos, ps):
             if hasattr(m, n):
                 delattr(m, n)
@@ -130,41 +139,60 @@ class FlattenParamsWrapper(nn.Module):
             if hasattr(m, n):
                 delattr(m, n)
             m.register_parameter(n, getattr(shared_m, shared_n))
-        del self.flat_param
+        if hasattr(self, "flat_param"):
+            del self.flat_param

     def _unflatten_params_as_views(self) -> None:
         assert self.is_flattened
-        ps = self._get_param_views()
+        ps = self._get_param_views(self.flat_param)
         for (m, n), p in zip(self._param_infos, ps):
             setattr(m, n, p)  # This will set as plain attr
         for (m, n, shared_m, shared_n) in self._shared_param_infos:
             setattr(m, n, getattr(shared_m, shared_n))

     @contextmanager
-    def unflatten_params(self, recurse: bool = True) -> Generator:
+    def unflatten_params(self, recurse: bool = True, flat_param: Optional[Tensor] = None) -> Generator:
         """
-        Unflatten params (optionally recursively on all nested instances).
-        If the current instance is already unflattened, then it will remain
-        unflattened after the context manager exits.
+        Unflatten params. If the current instance is already unflattened, then
+        it will remain unflattened after the context manager exits.
+
+        Args:
+            recurse (bool, Optional): recursively unflatten all nested instances
+                (default: True)
+            flat_param (Tensor, Optional): flat param to use for unflattening.
+                If provided, the current instance must be in a flattened state
+                at the start of the context manager. The provided Tensor must be
+                appropriately sized and will only be used within the context
+                manager. After the context manager exits, we will revert to
+                using ``self.flat_param`` (default: None).
         """
         if recurse:
             with ExitStack() as stack:
                 # unflatten any nested FlattenParamsWrapper instances
-                for module in self.modules():
+                for name, module in self.named_modules():
                     if isinstance(module, FlattenParamsWrapper):
-                        stack.enter_context(module.unflatten_params(recurse=False))
+                        is_self = name == ""
+                        stack.enter_context(
+                            module.unflatten_params(recurse=False, flat_param=flat_param if is_self else None)
+                        )
                 # yield to the caller, with unflattened params in all nested instances
                 yield
                 # exiting from the ExitStack will re-flatten params
                 return
         else:
+            assert (
+                flat_param is None or self.is_flattened
+            ), "Unflattening with custom flat_param requires current instance to be flattened"
             orig_flattened = self.is_flattened
-            if self.is_flattened:
-                self._unflatten_params()
+            if orig_flattened:
+                orig_flat_param = self.flat_param
+                self._unflatten_params(flat_param)
             yield
             if orig_flattened:
-                self._flatten_params()
+                self._flatten_params(orig_flat_param)
+                self._unflatten_params_as_views()

     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
......
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/wrap/test_wrap.py
-# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 #
-# This source code is licensed under the MIT license found in the
+# This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

 import functools
 import itertools
 from math import inf
@@ -182,8 +183,13 @@ class TestComparisonToPyTorchDDP(DistributedTest):
     PyTorch DDP vs. FullyShardedDataParallel.
     """

-    def test_nested_all_wrapped_model(self):
-        config = {"mixed_precision": True}
+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_nested_wrapped_model(self, config):
+        test_fn = functools.partial(self._test_identical_outputs, NestedWrappedModule, config)
+        spawn_and_init(test_fn)
+
+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_nested_all_wrapped_model(self, config):
         model_fn = functools.partial(NestedWrappedModule, wrap_everything=True)
         test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
         spawn_and_init(test_fn)
@@ -280,6 +286,9 @@ class TestComparisonToPyTorchDDP(DistributedTest):
         model = ref_ddp_fn(model, group)
         ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
         ref_state_dict = model.module.state_dict()
+        if config.get("cpu_offload", False):
+            for k in ref_state_dict.keys():
+                ref_state_dict[k] = ref_state_dict[k].cpu()

         # Confirm we get the same behavior using FullyShardedDataParallel.
         model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
@@ -456,9 +465,8 @@ class TestSaveLoadStateDict(DistributedTest):
     def _test_state_dict_before_forward(cls, config, rank, group):
         ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
         sd = ddp_model.state_dict()
-        expected_dtype = torch.float16 if ddp_model.mixed_precision else torch.float32
         wt = sd["embed_tokens.weight"]
-        assert wt.dtype == expected_dtype, f"got dtype {wt.dtype} expected {expected_dtype}"
+        assert wt.dtype == torch.float32, f"got dtype {wt.dtype} expected torch.float32"
         cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision)

     @classmethod
@@ -480,15 +488,11 @@ class TestSaveLoadStateDict(DistributedTest):
     @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
     def test_nested_wrapped_model(self, config):
-        if config["mixed_precision"]:
-            return  # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
         test_fn = functools.partial(self._test_nested_wrapped_model, config=config)
         spawn_and_init(test_fn)

     @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
     def test_nested_wrapped_model_local_state_dict(self, config):
-        if config["mixed_precision"]:
-            return  # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
         test_fn = functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config)
         spawn_and_init(test_fn)
@@ -501,6 +505,8 @@ class TestSaveLoadStateDict(DistributedTest):
         ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}

         # Create a nested FSDP-wrapped instance.
+        if config["mixed_precision"]:
+            config["compute_dtype"] = torch.float32
         model = NestedWrappedModule(group, config)
         model = FullyShardedDataParallel(model, group, **config).cuda()
         cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
......
tests/nn/data_parallel/test_fsdp_summon_full_params.py (new file):

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import gc
import unittest

from parameterized import parameterized
import torch

from .test_fsdp import CONFIG_OPTIONS, DistributedTest, rename_test, spawn_and_init


def get_cuda_mem():
    torch.cuda.synchronize()
    gc.collect()
    return torch.cuda.memory_allocated()


class TestMemory(DistributedTest):
    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_memory(self, config):
        spawn_and_init(functools.partial(self._test_memory, config))

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_memory_volatile(self, config):
        spawn_and_init(functools.partial(self._test_memory, config, volatile=True))

    @classmethod
    def _test_memory(self, config, rank, group, volatile=False):
        model = self.get_wrapped_model(group, cuda_first=False, config=config)
        self._train_for_several_steps(model, 1, autocast=model.mixed_precision)

        mems = [get_cuda_mem()]

        with model.summon_full_params(volatile=volatile):
            mems.append(get_cuda_mem())
            assert mems[1] >= mems[0]

            state_dict = model.state_dict()
            mems.append(get_cuda_mem())
            assert mems[2] >= mems[1]

        mems.append(get_cuda_mem())
        assert mems[3] <= mems[2]

        del state_dict
        mems.append(get_cuda_mem())
        assert mems[4] == mems[0]


class TestPersistence(DistributedTest):
    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_non_volatile(self, config):
        spawn_and_init(functools.partial(self._test_persistence, config))

    @classmethod
    def _test_persistence(self, config, rank, group, volatile=False):
        model = self.get_wrapped_model(group, cuda_first=False, config=config)
        with model.summon_full_params(volatile=False):
            model.module.embed_tokens.weight.data.fill_(42)
        with model.summon_full_params():
            # non-volatile changes are persisted
            assert torch.all(model.module.embed_tokens.weight.data == 42.0)


if __name__ == "__main__":
    unittest.main()
@@ -29,17 +29,26 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
     my_lr = 0.1

+    device = torch.device("cuda")
+    if fsdp_config.get("mixed_precision", False):
+        dtype = torch.float16
+        fsdp_config["fp32_reduce_scatter"] = True
+    else:
+        dtype = torch.float32

     if test_case["assert_ref_out"]:
         with torch.no_grad():
             # Compute one iteration local output.
-            weight = model.weight.T.clone().cuda()
-            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
+            fp32_weight = model.weight.T.clone().to(device)
+            weight = fp32_weight.to(dtype)
+            v = torch.Tensor(test_case["inputs"][0][rank]).to(device, dtype)
             ref_forward_output_my_rank = torch.matmul(v, weight)
             # Compute one iteration global weight update.
-            v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda()
-            grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
-            ref_weight_out = weight - grad.T * my_lr
+            v = torch.Tensor(test_case["inputs"][0][:world_size]).to(device, dtype)
+            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(world_size)
+            ref_weight_out = fp32_weight - grad.T * my_lr
+            assert ref_weight_out.dtype == torch.float32

-    model.to("cuda")
+    model.to(device)  # not dtype, since FSDP will manage mixed precision internally
     assert isinstance(fsdp_config, dict), str(fsdp_config)
     model = FSDP(model, **fsdp_config)
     optim = SGD(model.parameters(), lr=my_lr)
@@ -47,9 +56,9 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
         assert len(inputs) == 1 or not test_case["assert_ref_out"]
         assert len(inputs[0]) >= world_size
         for in_data in inputs:
-            in_data = Tensor(in_data[rank]).cuda()
+            in_data = Tensor(in_data[rank]).to(device, dtype)
             out = model(in_data)
-            out.sum().backward()
+            out.float().sum().backward()
             optim.step()
             optim.zero_grad()
         if test_case["assert_ref_out"]:
@@ -70,7 +79,7 @@
 @skip_if_single_gpu
 @pytest.mark.parametrize("test_case", [{"inputs": [torch.rand(8, 3)], "assert_ref_out": True}])
 @pytest.mark.parametrize(
-    "fsdp_config", [{}, {"flatten_parameters": False}],
+    "fsdp_config", [{}, {"flatten_parameters": False}, {"mixed_precision": True}],
 )
 @pytest.mark.parametrize("world_size", list(range(2, 9)))
 def test_one_iteration(world_size, test_case, fsdp_config):
......
@@ -7,6 +7,7 @@
 Test FlattenParamsWrapper
 """

+from collections import OrderedDict
 import unittest

 import torch
@@ -196,6 +197,40 @@ class TestFlattenParams(unittest.TestCase):
         assert objects_are_equal(ref_output, new_output)

+    def test_unflatten_params(self):
+        for module_init_fn in self._get_module_init_fns():
+            module = FlattenParamsWrapper(module_init_fn())
+            buffers = {k.replace("_fpw_module.", "") for k, _ in module.named_buffers()}
+
+            def clone_state_dict():
+                return OrderedDict((k, v.clone()) for k, v in module.state_dict().items())
+
+            ref_flat_param = module.flat_param.clone()
+            with module.unflatten_params():
+                ref_state_dict = clone_state_dict()
+            assert not torch.all(ref_flat_param == 0)
+
+            # confirm that unflatten_params reflects values from new_flat_param
+            new_flat_param = torch.full_like(module.flat_param, fill_value=42.0)
+            with module.unflatten_params(flat_param=new_flat_param):
+                new_state_dict = clone_state_dict()
+            assert new_state_dict.keys() == ref_state_dict.keys()
+            for k, v in new_state_dict.items():
+                if k in buffers:  # buffers are not changed
+                    torch.testing.assert_allclose(v, ref_state_dict[k])
+                else:  # params reflect new_flat_param value
+                    assert torch.all(v == 42.0)
+
+            # after context manager exits, we go back to previous (reference) state
+            torch.testing.assert_allclose(module.flat_param, ref_flat_param)
+            with module.unflatten_params():
+                ref_state_dict2 = clone_state_dict()
+            assert objects_are_equal(ref_state_dict, ref_state_dict2)
+
+            # if we load the new_state_dict, then the flat param should match new_flat_param
+            module.load_state_dict(new_state_dict)
+            torch.testing.assert_allclose(module.flat_param, new_flat_param)

 @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
 class TestFlattenParamsCUDA(TestFlattenParams):
......