Unverified commit 14abed6e authored by Myle Ott, committed by GitHub

[FSDP] [feat] Add state_dict_device option (#579)

parent 121b9db0
......@@ -163,6 +163,10 @@ class FullyShardedDataParallel(nn.Module):
with the proper state at each rank. This is useful for situations, like Mixture Of Experts,
where all but a few parameters can fit on one node.
Default: False
state_dict_device (torch.device, Optional):
device for parameters returned by :func:`state_dict`. If not given,
this will default to ``compute_device``. Note that only the device
type will be respected (e.g., "cuda:0" and "cuda:1" are the same).
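Example (illustrative sketch; assumes ``my_module`` is an ``nn.Module``
and the default process group has already been initialized)::

    fsdp = FullyShardedDataParallel(my_module, state_dict_device=torch.device("cpu"))
    state = fsdp.state_dict()  # tensors in ``state`` are returned on CPU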
"""
def __init__(
......@@ -180,6 +184,7 @@ class FullyShardedDataParallel(nn.Module):
bucket_cap_mb: int = 25,
compute_device: Optional[torch.device] = None,
no_broadcast_optim_state: Optional[bool] = False,
state_dict_device: Optional[torch.device] = None,
):
super().__init__()
self.process_group = process_group or dist.new_group()
......@@ -194,26 +199,21 @@ class FullyShardedDataParallel(nn.Module):
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.compute_device = compute_device or _get_default_cuda_device(module)
self.uncollected_opt_state: Dict[int, Dict] = {}
self.no_broadcast_optim_state = no_broadcast_optim_state
self.state_dict_device = state_dict_device or self.compute_device
self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self.numel_padded_per_param: List[int] = []
self.compute_device = compute_device
if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.cpu_offload and not self.mixed_precision:
raise ValueError("cpu_offload requires mixed_precision=True")
if self.compute_device is None:
# Try to infer CUDA device from module parameters.
self.compute_device = next(module.parameters()).device
if self.compute_device.type != "cuda":
# Fall back to current CUDA device.
self.compute_device = torch.device("cuda")
validate_process_group(self.compute_device, self.process_group)
enable_pytorch_sync_bn(module)
......@@ -545,7 +545,7 @@ class FullyShardedDataParallel(nn.Module):
if self._return_full_state_dict:
if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
with self.summon_full_params(volatile=True):
with self.summon_full_params(recurse=False, volatile=True):
state_dict = super().state_dict(*args, **kwargs)
else:
state_dict = super().state_dict(*args, **kwargs)
......@@ -1410,7 +1410,7 @@ class FullyShardedDataParallel(nn.Module):
sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
else:
sd = dummy_tensor # type: ignore
sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device) # type: ignore
sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)
if should_collect_state:
assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
......@@ -1501,6 +1501,15 @@ class FullyShardedDataParallel(nn.Module):
return full_optim_state_dict
def _get_default_cuda_device(module: nn.Module) -> torch.device:
"""Try to infer CUDA device from module parameters."""
compute_device = next(module.parameters()).device
if compute_device.type != "cuda":
# Fall back to current CUDA device.
compute_device = torch.device("cuda")
return compute_device
@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
"""
......@@ -1534,13 +1543,25 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
def _post_state_dict_hook(
module: nn.Module, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
module: FullyShardedDataParallel, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
) -> "OrderedDict[str, torch.Tensor]":
if module.training_state == TrainingState.SUMMON_FULL_PARAMS:
# We copy the state_dict since full param will be freed after
# we exit the summon_full_params() context.
# Assuming we are in a ``summon_full_params()`` context, we need to clone
# each tensor so that it does not get freed (in-place) when the context
# exits. At the same time, this hook can be called multiple times
# recursively, so we need to make sure that we only clone each tensor at
most once. Thus we add an attribute on the tensor called "_has_been_cloned"
# which keeps track of tensors that are no longer at risk of being freed.
for key in state_dict.keys():
if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
continue
if state_dict[key].device.type != module.state_dict_device.type:
state_dict[key] = state_dict[key].to(device=module.state_dict_device)
state_dict[key]._has_been_cloned = True
elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
# We copy the state_dict since full param will be freed after we
# exit the ``summon_full_params()`` context.
state_dict[key] = state_dict[key].clone()
state_dict[key]._has_been_cloned = True
# Remove "_fsdp_wrapped_module." prefix
replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
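A minimal standalone sketch of the clone-at-most-once pattern described in the hook's
comments above (illustrative only; ``ensure_cloned`` and ``demo_state`` are hypothetical
names, not part of FSDP):

import torch

def ensure_cloned(state):
    # Clone each tensor exactly once; repeated passes over the same dict are no-ops.
    for key, tensor in state.items():
        if getattr(tensor, "_has_been_cloned", False):
            continue
        state[key] = tensor.clone()
        state[key]._has_been_cloned = True
    return state

demo_state = {"w": torch.zeros(2)}
ensure_cloned(demo_state)
first_clone = demo_state["w"]
ensure_cloned(demo_state)  # second pass skips tensors that were already cloned
assert demo_state["w"] is first_clone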
......
......@@ -117,6 +117,10 @@ class Tensor:
data: Tensor = ...
names: List[str] = ...
#MODIFIED BY FULLY_SHARDED_DATA_PARALLEL
_has_been_cloned: Optional[bool] = ...
#END
def __init__(self, *args, **kwargs) -> None: ...
@property
......
......@@ -38,3 +38,4 @@ tests/nn/moe/test_top2gating.py
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py
tests/nn/data_parallel/test_fsdp_apply.py
tests/nn/data_parallel/test_fsdp_state_dict.py
......@@ -440,160 +440,6 @@ class TestSerialization(DistributedTest):
optim.step()
class TestLocalStateDict(DistributedTest):
@parameterized.expand([[True, True], [False, False]], name_func=rename_test)
def test_load_local_state_dict(self, flatten_params, mixed_precision):
test_fn = functools.partial(
self._load_local_and_train, {"flatten_parameters": flatten_params, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
"""Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
model = self.get_wrapped_model(
group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model, add_bn=False
) # Set bn=True here to show that BN doesn't get updated
state_1 = model.local_state_dict()
state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
assert len(state_1) > 0
model.load_local_state_dict(state_1)
weight_key = "flat_param" if model.flatten_parameters else "embed_tokens.weight"
state_1_weight = state_1[weight_key]
assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32"
if not model.flatten_parameters:
# The weight will be sharded since we access module.state_dict directly
state_1_module_weight = model.module.state_dict()[weight_key]
torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
self._train_for_several_steps(model, 1, model.mixed_precision)
state_2 = model.local_state_dict()
state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
model.load_local_state_dict(state_2)
assert state_1.keys() == state_2.keys()
# Assert that parameters were updated since before training
unchanged = []
unwrapped_model = model.module.module if config["flatten_parameters"] else model.module
buffers = {name for name, _ in unwrapped_model.named_buffers()}
for k in state_1:
if (state_before_training[k] == state_after_training[k]).all() and (k not in buffers):
unchanged.append(k)
if unchanged:
raise AssertionError(f"params {unchanged} not changed after training")
class TestSaveLoadStateDict(DistributedTest):
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_calling_state_dict_twice_mixed_precision(self, mixed_precision):
test_fn = functools.partial(
self._test_calling_state_dict_twice, {"flatten_parameters": False, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _test_calling_state_dict_twice(self, config, rank, group, **model_kwargs):
ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config, **model_kwargs)
autocast = ddp_model.mixed_precision
self._train_for_several_steps(ddp_model, 1, autocast)
ddp_model.state_dict()
ddp_model.state_dict() # second call
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_state_dict_after_forward(self, config):
test_fn = functools.partial(self._test_module_state_dict, config)
spawn_and_init(test_fn)
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_state_dict_before_forward(self, mixed_precision):
test_fn = functools.partial(
self._test_state_dict_before_forward, {"flatten_parameters": False, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _test_state_dict_before_forward(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
sd = ddp_model.state_dict()
wt = sd["embed_tokens.weight"]
assert wt.dtype == torch.float32, f"got dtype {wt.dtype} expected torch.float32"
cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision)
@classmethod
def _test_module_state_dict(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
autocast = ddp_model.mixed_precision
cls._train_for_several_steps(ddp_model, 2, autocast)
state_1 = ddp_model.state_dict()
# You must make a new FullyShardedDataParallel instance to use module.load_state_dict
unwrapped_model = TransformerWithSharedParams(group)
unwrapped_model.load_state_dict(state_1)
new_ddp_model = FullyShardedDataParallel(unwrapped_model, group, **config).cuda()
cls._train_for_several_steps(new_ddp_model, 2, autocast)
try:
ddp_model.load_state_dict(new_ddp_model.state_dict())
assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
except Exception:
pass
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model(self, config):
test_fn = functools.partial(self._test_nested_wrapped_model, config=config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model_local_state_dict(self, config):
test_fn = functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config)
spawn_and_init(test_fn)
@classmethod
def _test_nested_wrapped_model(cls, rank, group, config=None):
# Get reference state dict without any nested FSDP instances.
model = NestedWrappedModule(group, None).cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}
# Create a nested FSDP-wrapped instance.
if config["mixed_precision"]:
config["compute_dtype"] = torch.float32
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
# Round-trip state dict save/load/save.
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
model.load_state_dict(state_dict)
state_dict = model.state_dict()
assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
for key in ref_state_dict.keys():
assert objects_are_equal(
ref_state_dict[key], state_dict[key], raise_exception=False
), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
@classmethod
def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
# Create a nested FSDP-wrapped instance.
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
# Round trip state dict save/load/save.
ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
model.load_local_state_dict(ref_state_dict)
state_dict = model.local_state_dict()
assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
for key in ref_state_dict.keys():
assert objects_are_equal(
ref_state_dict[key], state_dict[key], raise_exception=False
), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
class TestHooks(DistributedTest):
# Feel free to modify these tests as the implementation changes.
# They aspire to make sure that backward hooks are registered and used
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import unittest
from parameterized import parameterized
import torch
from torch import nn
from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import objects_are_equal
from .test_fsdp import (
CONFIG_OPTIONS,
DistributedTest,
NestedWrappedModule,
TransformerWithSharedParams,
rename_test,
spawn_and_init,
)
class TestLocalStateDict(DistributedTest):
@parameterized.expand([[True, True], [False, False]], name_func=rename_test)
def test_load_local_state_dict(self, flatten_params, mixed_precision):
test_fn = functools.partial(
self._load_local_and_train, {"flatten_parameters": flatten_params, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
"""Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
model = self.get_wrapped_model(
group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model, add_bn=False
) # Set bn=True here to show that BN doesn't get updated
state_1 = model.local_state_dict()
state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
assert len(state_1) > 0
model.load_local_state_dict(state_1)
weight_key = "flat_param" if model.flatten_parameters else "embed_tokens.weight"
state_1_weight = state_1[weight_key]
assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32"
if not model.flatten_parameters:
# The weight will be sharded since we access module.state_dict directly
state_1_module_weight = model.module.state_dict()[weight_key]
torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
self._train_for_several_steps(model, 1, model.mixed_precision)
state_2 = model.local_state_dict()
state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
model.load_local_state_dict(state_2)
assert state_1.keys() == state_2.keys()
# Assert that parameters were updated since before training
unchanged = []
unwrapped_model = model.module.module if config["flatten_parameters"] else model.module
buffers = {name for name, _ in unwrapped_model.named_buffers()}
for k in state_1:
if (state_before_training[k] == state_after_training[k]).all() and (k not in buffers):
unchanged.append(k)
if unchanged:
raise AssertionError(f"params {unchanged} not changed after training")
class TestSaveLoadStateDict(DistributedTest):
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_calling_state_dict_twice_mixed_precision(self, mixed_precision):
test_fn = functools.partial(
self._test_calling_state_dict_twice, {"flatten_parameters": False, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _test_calling_state_dict_twice(self, config, rank, group, **model_kwargs):
ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config, **model_kwargs)
autocast = ddp_model.mixed_precision
self._train_for_several_steps(ddp_model, 1, autocast)
ddp_model.state_dict()
ddp_model.state_dict() # second call
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_state_dict_after_forward(self, config):
test_fn = functools.partial(self._test_module_state_dict, config)
spawn_and_init(test_fn)
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_state_dict_before_forward(self, mixed_precision):
test_fn = functools.partial(
self._test_state_dict_before_forward, {"flatten_parameters": False, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _test_state_dict_before_forward(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
sd = ddp_model.state_dict()
wt = sd["embed_tokens.weight"]
assert wt.dtype == torch.float32, f"got dtype {wt.dtype} expected torch.float32"
cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision)
@classmethod
def _test_module_state_dict(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
autocast = ddp_model.mixed_precision
cls._train_for_several_steps(ddp_model, 2, autocast)
state_1 = ddp_model.state_dict()
# You must make a new FullyShardedDataParallel instance to use module.load_state_dict
unwrapped_model = TransformerWithSharedParams(group)
unwrapped_model.load_state_dict(state_1)
new_ddp_model = FullyShardedDataParallel(unwrapped_model, group, **config).cuda()
cls._train_for_several_steps(new_ddp_model, 2, autocast)
try:
ddp_model.load_state_dict(new_ddp_model.state_dict())
assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
except Exception:
pass
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model(self, config):
test_fn = functools.partial(self._test_nested_wrapped_model, config=config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model_local_state_dict(self, config):
test_fn = functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config)
spawn_and_init(test_fn)
@classmethod
def _test_nested_wrapped_model(cls, rank, group, config=None):
# Get reference state dict without any nested FSDP instances.
model = NestedWrappedModule(group, None).cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}
# Create a nested FSDP-wrapped instance.
if config["mixed_precision"]:
config["compute_dtype"] = torch.float32
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
# Round-trip state dict save/load/save.
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
model.load_state_dict(state_dict)
state_dict = model.state_dict()
assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
for key in ref_state_dict.keys():
assert objects_are_equal(
ref_state_dict[key], state_dict[key], raise_exception=False
), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
@classmethod
def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
# Create a nested FSDP-wrapped instance.
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
# Round trip state dict save/load/save.
ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
model.load_local_state_dict(ref_state_dict)
state_dict = model.local_state_dict()
assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
for key in ref_state_dict.keys():
assert objects_are_equal(
ref_state_dict[key], state_dict[key], raise_exception=False
), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
class TestStateDictDeviceDtype(DistributedTest):
@parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
def test_state_dict_device(self, mixed_precision, cpu_offload):
test_fn = functools.partial(
self._test_state_dict_device, {"cpu_offload": cpu_offload, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
def test_state_dict_device_cuda(self, mixed_precision, cpu_offload):
test_fn = functools.partial(
self._test_state_dict_device,
{"cpu_offload": cpu_offload, "mixed_precision": mixed_precision, "state_dict_device": torch.device("cuda")},
)
spawn_and_init(test_fn)
@parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
def test_state_dict_device_cpu(self, mixed_precision, cpu_offload):
test_fn = functools.partial(
self._test_state_dict_device,
{"cpu_offload": cpu_offload, "mixed_precision": mixed_precision, "state_dict_device": torch.device("cpu")},
)
spawn_and_init(test_fn)
def test_state_dict_device_pure_fp16(self):
test_fn = functools.partial(
self._test_state_dict_device,
{"cpu_offload": False, "mixed_precision": False, "compute_dtype": torch.float16},
# pure_fp16 is similar to the --memory-efficient-fp16 option in fairseq
pure_fp16=True,
)
spawn_and_init(test_fn)
@classmethod
def _test_state_dict_device(self, config, rank, group, pure_fp16=False, **model_kwargs):
model = TransformerWithSharedParams(group, **model_kwargs)
if pure_fp16:
assert not config["mixed_precision"]
model = model.half()
fsdp_model = FullyShardedDataParallel(model, group, **config)
if not config["cpu_offload"]:
fsdp_model = fsdp_model.cuda()
autocast = fsdp_model.mixed_precision or pure_fp16
self._train_for_several_steps(fsdp_model, 1, autocast)
sd = fsdp_model.state_dict()
sd_device = config.get("state_dict_device")
for k, v in sd.items():
if config["cpu_offload"] or (sd_device is not None and sd_device.type == "cpu"):
assert v.device.type == "cpu", v.device.type
else:
assert v.device.type == "cuda", v.device.type
expected_dtype = torch.float16 if pure_fp16 else torch.float32
buffers = {
k.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "") for k, _ in fsdp_model.named_buffers()
}
for k, v in sd.items():
if not torch.is_floating_point(v):
continue
if k in buffers:
assert v.dtype == fsdp_model.buffer_dtype, f"{v.dtype} != {fsdp_model.buffer_dtype}"
else:
assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
if __name__ == "__main__":
unittest.main()