Unverified Commit 14abed6e authored by Myle Ott, committed by GitHub

[FSDP] [feat] Add state_dict_device option (#579)

parent 121b9db0
...@@ -163,6 +163,10 @@ class FullyShardedDataParallel(nn.Module):
with the proper state at each rank. This is useful for situations, like Mixture Of Experts,
where all but a few parameters can fit on one node.
Default: False
state_dict_device (torch.device, Optional):
device for parameters returned by :func:`state_dict`. If not given,
this will default to ``compute_device``. Note that only the device
type will be respected (e.g., "cuda:0" and "cuda:1" are the same).
""" """
def __init__( def __init__(
...@@ -180,6 +184,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -180,6 +184,7 @@ class FullyShardedDataParallel(nn.Module):
bucket_cap_mb: int = 25, bucket_cap_mb: int = 25,
compute_device: Optional[torch.device] = None, compute_device: Optional[torch.device] = None,
no_broadcast_optim_state: Optional[bool] = False, no_broadcast_optim_state: Optional[bool] = False,
state_dict_device: Optional[torch.device] = None,
):
super().__init__()
self.process_group = process_group or dist.new_group()
...@@ -194,26 +199,21 @@ class FullyShardedDataParallel(nn.Module):
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.compute_device = compute_device or _get_default_cuda_device(module)
self.uncollected_opt_state: Dict[int, Dict] = {}
self.no_broadcast_optim_state = no_broadcast_optim_state
self.state_dict_device = state_dict_device or self.compute_device
self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self.numel_padded_per_param: List[int] = []
self.compute_device = compute_device
if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.cpu_offload and not self.mixed_precision:
raise ValueError("cpu_offload requires mixed_precision=True")
if self.compute_device is None:
# Try to infer CUDA device from module parameters.
self.compute_device = next(module.parameters()).device
if self.compute_device.type != "cuda":
# Fall back to current CUDA device.
self.compute_device = torch.device("cuda")
validate_process_group(self.compute_device, self.process_group)
enable_pytorch_sync_bn(module)
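As a rough usage sketch (not part of this commit): with the new option, computation can stay on the GPU while the tensors returned by state_dict() are materialized on CPU. This assumes torch.distributed has already been initialized and a CUDA device is available; the wrapped module below is illustrative only.

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel

# Any nn.Module works; a small Linear keeps the sketch minimal.
module = nn.Linear(16, 16).cuda()

# Shard the module, but ask state_dict() to return CPU tensors.
fsdp = FullyShardedDataParallel(module, state_dict_device=torch.device("cpu"))
sd = fsdp.state_dict()
assert all(t.device.type == "cpu" for t in sd.values())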
...@@ -545,7 +545,7 @@ class FullyShardedDataParallel(nn.Module):
if self._return_full_state_dict:
if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
with self.summon_full_params(recurse=False, volatile=True):
state_dict = super().state_dict(*args, **kwargs)
else:
state_dict = super().state_dict(*args, **kwargs)
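A common checkpointing pattern built on this path, shown only as a hedged sketch: it assumes the fsdp wrapper from the sketch above, an already initialized process group, and an illustrative file name.

import torch
import torch.distributed as dist

# state_dict() gathers the full, unsharded parameters via summon_full_params(),
# so every rank produces the same result; write it out from rank 0 only.
full_sd = fsdp.state_dict()
if dist.get_rank() == 0:
    torch.save(full_sd, "checkpoint.pt")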
...@@ -1410,7 +1410,7 @@ class FullyShardedDataParallel(nn.Module):
sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
else:
sd = dummy_tensor # type: ignore
sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)
if should_collect_state:
assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
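broadcast_object here is fairscale's internal helper. As a rough illustration of the same idea using the public torch.distributed API (the function name and arguments below are illustrative, and an initialized process group is assumed):

import torch.distributed as dist

def broadcast_state(sd, src_rank, group=None):
    # Rank `src_rank` supplies the payload; every other rank passes a
    # placeholder and receives the object in place.
    container = [sd if dist.get_rank(group) == src_rank else None]
    dist.broadcast_object_list(container, src=src_rank, group=group)
    return container[0]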
...@@ -1501,6 +1501,15 @@ class FullyShardedDataParallel(nn.Module):
return full_optim_state_dict
def _get_default_cuda_device(module: nn.Module) -> torch.device:
"""Try to infer CUDA device from module parameters."""
compute_device = next(module.parameters()).device
if compute_device.type != "cuda":
# Fall back to current CUDA device.
compute_device = torch.device("cuda")
return compute_device
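A standalone illustration of the fallback logic above, with nothing FSDP-specific; the module here is only an example.

import torch
import torch.nn as nn

m = nn.Linear(4, 4)                   # parameters start on CPU
device = next(m.parameters()).device  # -> device(type='cpu')
if device.type != "cuda":
    device = torch.device("cuda")     # fall back to the current CUDA device
print(device)                         # cuda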
@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
"""
...@@ -1534,13 +1543,25 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
def _post_state_dict_hook(
module: FullyShardedDataParallel, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
) -> "OrderedDict[str, torch.Tensor]":
# Assuming we are in a ``summon_full_params()`` context, we need to clone
# each tensor so that it does not get freed (in-place) when the context
# exits. At the same time, this hook can be called multiple times
# recursively, so we need to make sure that we only clone each tensor at
# most once. Thus we add an attribute on the tensor called "_has_been_cloned"
# which keeps track of tensors that are no longer at risk of being freed.
for key in state_dict.keys():
if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
continue
if state_dict[key].device.type != module.state_dict_device.type:
state_dict[key] = state_dict[key].to(device=module.state_dict_device)
state_dict[key]._has_been_cloned = True
elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
# We copy the state_dict since full param will be freed after we
# exit the ``summon_full_params()`` context.
state_dict[key] = state_dict[key].clone()
state_dict[key]._has_been_cloned = True
# Remove "_fsdp_wrapped_module." prefix
replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
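The clone-or-move-once guard can be exercised on its own. A minimal sketch (not part of the commit) applied to a plain dict of tensors, with illustrative names:

import torch

def materialize_once(state_dict, target_device):
    for key, tensor in state_dict.items():
        if getattr(tensor, "_has_been_cloned", False):
            continue  # already safe to hand back to the caller
        if tensor.device.type != target_device.type:
            state_dict[key] = tensor.to(device=target_device)  # copy across device types
        else:
            state_dict[key] = tensor.clone()  # detach from storage that may be freed later
        state_dict[key]._has_been_cloned = True  # tensors accept ad-hoc Python attributes

sd = {"w": torch.randn(2, 2)}
materialize_once(sd, torch.device("cpu"))
materialize_once(sd, torch.device("cpu"))  # second call is a no-op thanks to the flag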
...
...@@ -117,6 +117,10 @@ class Tensor:
data: Tensor = ...
names: List[str] = ...
#MODIFIED BY FULLY_SHARDED_DATA_PARALLEL
_has_been_cloned: Optional[bool] = ...
#END
def __init__(self, *args, **kwargs) -> None: ...
@property
...
...@@ -38,3 +38,4 @@ tests/nn/moe/test_top2gating.py
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py
tests/nn/data_parallel/test_fsdp_apply.py
tests/nn/data_parallel/test_fsdp_state_dict.py
...@@ -440,160 +440,6 @@ class TestSerialization(DistributedTest):
optim.step()
class TestHooks(DistributedTest):
# Feel free to modify these tests as the implementation changes.
# They aspire to make sure that backward hooks are registered and used
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import unittest
from parameterized import parameterized
import torch
from torch import nn
from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import objects_are_equal
from .test_fsdp import (
CONFIG_OPTIONS,
DistributedTest,
NestedWrappedModule,
TransformerWithSharedParams,
rename_test,
spawn_and_init,
)
class TestLocalStateDict(DistributedTest):
@parameterized.expand([[True, True], [False, False]], name_func=rename_test)
def test_load_local_state_dict(self, flatten_params, mixed_precision):
test_fn = functools.partial(
self._load_local_and_train, {"flatten_parameters": flatten_params, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
"""Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
model = self.get_wrapped_model(
group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model, add_bn=False
) # Set bn=True here to show that BN doesn't get updated
state_1 = model.local_state_dict()
state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
assert len(state_1) > 0
model.load_local_state_dict(state_1)
weight_key = "flat_param" if model.flatten_parameters else "embed_tokens.weight"
state_1_weight = state_1[weight_key]
assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32"
if not model.flatten_parameters:
# The weight will be sharded since we access module.state_dict directly
state_1_module_weight = model.module.state_dict()[weight_key]
torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
self._train_for_several_steps(model, 1, model.mixed_precision)
state_2 = model.local_state_dict()
state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
model.load_local_state_dict(state_2)
assert state_1.keys() == state_2.keys()
# Assert that parameters were updated since before training
unchanged = []
unwrapped_model = model.module.module if config["flatten_parameters"] else model.module
buffers = {name for name, _ in unwrapped_model.named_buffers()}
for k in state_1:
if (state_before_training[k] == state_after_training[k]).all() and (k not in buffers):
unchanged.append(k)
if unchanged:
raise AssertionError(f"params {unchanged} not changed after training")
class TestSaveLoadStateDict(DistributedTest):
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_calling_state_dict_twice_mixed_precision(self, mixed_precision):
test_fn = functools.partial(
self._test_calling_state_dict_twice, {"flatten_parameters": False, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _test_calling_state_dict_twice(self, config, rank, group, **model_kwargs):
ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config, **model_kwargs)
autocast = ddp_model.mixed_precision
self._train_for_several_steps(ddp_model, 1, autocast)
ddp_model.state_dict()
ddp_model.state_dict() # second call
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_state_dict_after_forward(self, config):
test_fn = functools.partial(self._test_module_state_dict, config)
spawn_and_init(test_fn)
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_state_dict_before_forward(self, mixed_precision):
test_fn = functools.partial(
self._test_state_dict_before_forward, {"flatten_parameters": False, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@classmethod
def _test_state_dict_before_forward(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
sd = ddp_model.state_dict()
wt = sd["embed_tokens.weight"]
assert wt.dtype == torch.float32, f"got dtype {wt.dtype} expected torch.float32"
cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision)
@classmethod
def _test_module_state_dict(cls, config, rank, group):
ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
autocast = ddp_model.mixed_precision
cls._train_for_several_steps(ddp_model, 2, autocast)
state_1 = ddp_model.state_dict()
# You must make a new FullyShardedDataParallel instance to use module.load_state_dict
unwrapped_model = TransformerWithSharedParams(group)
unwrapped_model.load_state_dict(state_1)
new_ddp_model = FullyShardedDataParallel(unwrapped_model, group, **config).cuda()
cls._train_for_several_steps(new_ddp_model, 2, autocast)
try:
ddp_model.load_state_dict(new_ddp_model.state_dict())
assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
except Exception:
pass
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model(self, config):
test_fn = functools.partial(self._test_nested_wrapped_model, config=config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model_local_state_dict(self, config):
test_fn = functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config)
spawn_and_init(test_fn)
@classmethod
def _test_nested_wrapped_model(cls, rank, group, config=None):
# Get reference state dict without any nested FSDP instances.
model = NestedWrappedModule(group, None).cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}
# Create a nested FSDP-wrapped instance.
if config["mixed_precision"]:
config["compute_dtype"] = torch.float32
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
# Round-trip state dict save/load/save.
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
model.load_state_dict(state_dict)
state_dict = model.state_dict()
assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
for key in ref_state_dict.keys():
assert objects_are_equal(
ref_state_dict[key], state_dict[key], raise_exception=False
), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
@classmethod
def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
# Create a nested FSDP-wrapped instance.
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
# Round trip state dict save/load/save.
ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
model.load_local_state_dict(ref_state_dict)
state_dict = model.local_state_dict()
assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
for key in ref_state_dict.keys():
assert objects_are_equal(
ref_state_dict[key], state_dict[key], raise_exception=False
), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
class TestStateDictDeviceDtype(DistributedTest):
@parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
def test_state_dict_device(self, mixed_precision, cpu_offload):
test_fn = functools.partial(
self._test_state_dict_device, {"cpu_offload": cpu_offload, "mixed_precision": mixed_precision}
)
spawn_and_init(test_fn)
@parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
def test_state_dict_device_cuda(self, mixed_precision, cpu_offload):
test_fn = functools.partial(
self._test_state_dict_device,
{"cpu_offload": cpu_offload, "mixed_precision": mixed_precision, "state_dict_device": torch.device("cuda")},
)
spawn_and_init(test_fn)
@parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
def test_state_dict_device_cpu(self, mixed_precision, cpu_offload):
test_fn = functools.partial(
self._test_state_dict_device,
{"cpu_offload": cpu_offload, "mixed_precision": mixed_precision, "state_dict_device": torch.device("cpu")},
)
spawn_and_init(test_fn)
def test_state_dict_device_pure_fp16(self):
test_fn = functools.partial(
self._test_state_dict_device,
{"cpu_offload": False, "mixed_precision": False, "compute_dtype": torch.float16},
# pure_fp16 is similar to the --memory-efficient-fp16 option in fairseq
pure_fp16=True,
)
spawn_and_init(test_fn)
@classmethod
def _test_state_dict_device(self, config, rank, group, pure_fp16=False, **model_kwargs):
model = TransformerWithSharedParams(group, **model_kwargs)
if pure_fp16:
assert not config["mixed_precision"]
model = model.half()
fsdp_model = FullyShardedDataParallel(model, group, **config)
if not config["cpu_offload"]:
fsdp_model = fsdp_model.cuda()
autocast = fsdp_model.mixed_precision or pure_fp16
self._train_for_several_steps(fsdp_model, 1, autocast)
sd = fsdp_model.state_dict()
sd_device = config.get("state_dict_device")
for k, v in sd.items():
if config["cpu_offload"] or (sd_device is not None and sd_device.type == "cpu"):
assert v.device.type == "cpu", v.device.type
else:
assert v.device.type == "cuda", v.device.type
expected_dtype = torch.float16 if pure_fp16 else torch.float32
buffers = {
k.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "") for k, _ in fsdp_model.named_buffers()
}
for k, v in sd.items():
if not torch.is_floating_point(v):
continue
if k in buffers:
assert v.dtype == fsdp_model.buffer_dtype, f"{v.dtype} != {fsdp_model.buffer_dtype}"
else:
assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
if __name__ == "__main__":
unittest.main()
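To run just the new test file in isolation (assuming pytest is installed and the machine can launch the distributed workers that spawn_and_init requires; running it directly with unittest via python tests/nn/data_parallel/test_fsdp_state_dict.py works as well):

import pytest

pytest.main(["tests/nn/data_parallel/test_fsdp_state_dict.py"])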