"src/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "f33c7fd968b8968674bb4ef07575b7075d60a5f1"
Unverified commit 25cebf85, authored by Pete and committed by GitHub

Fix buffer dtype in `FSDP.state_dict()` when using mixed precision (#705)

* add failing test for buffer dtype

* fix buffer dtype issue

* update CHANGELOG

* fix
parent 3443a635
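
The bug: with `mixed_precision=True`, FSDP keeps buffers in the compute dtype (typically fp16), so `state_dict()` could return them in half precision. The sketch below illustrates the intended behavior only; it is not FSDP's implementation, and it uses a plain `nn.BatchNorm1d` whose fp16 running stats stand in for FSDP buffers: cast floating-point buffers to fp32 while the state dict is collected, then restore their previous dtype.

```python
from contextlib import contextmanager
from typing import Iterator

import torch
import torch.nn as nn


@contextmanager
def buffers_as_fp32(module: nn.Module) -> Iterator[None]:
    """Temporarily cast floating-point buffers to fp32, restoring their dtypes on exit."""
    # Remember each floating-point buffer's original dtype.
    saved = {
        name: buf.dtype
        for name, buf in module.named_buffers()
        if torch.is_floating_point(buf)
    }
    for name, buf in module.named_buffers():
        if name in saved:
            buf.data = buf.data.to(torch.float32)
    try:
        yield
    finally:
        # Restore the original (e.g. fp16) dtypes after the state dict is taken.
        for name, buf in module.named_buffers():
            if name in saved:
                buf.data = buf.data.to(saved[name])


# Example: BatchNorm running stats behave like FSDP buffers cast to fp16.
m = nn.BatchNorm1d(4).half()
with buffers_as_fp32(m):
    sd = {k: v.clone() for k, v in m.state_dict().items()}
assert sd["running_mean"].dtype == torch.float32  # state dict sees fp32 buffers
assert m.running_mean.dtype == torch.float16      # module buffers restored afterwards
```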
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
 - checkpointing: use dummy tensor to ensure backward pass is called [#701]
+- FSDP: fixed bug where buffers returned in `state_dict()` could still be half precision when `mixed_precision` is set to `True`.
 ### Added
@@ -653,17 +653,21 @@ class FullyShardedDataParallel(nn.Module):
         if torch.cuda.is_available():
             torch.cuda.synchronize()
         self._lazy_init()
-        if self.mixed_precision:
-            # Buffers dtype stays consistent with parameters.
-            self._cast_buffers(dtype=torch.float32)
+
+        def maybe_cast_buffers(dtype: Optional[torch.dtype] = None) -> None:
+            if self.mixed_precision:
+                self._cast_buffers(dtype=dtype)
+
         if self._return_full_state_dict:
             if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
                 with self.summon_full_params(recurse=False, volatile=True):
+                    maybe_cast_buffers(torch.float32)
                     state_dict = super().state_dict(*args, **kwargs)
             else:
+                maybe_cast_buffers(torch.float32)
                 state_dict = super().state_dict(*args, **kwargs)
         else:
+            maybe_cast_buffers(torch.float32)
             if self.flatten_parameters:
                 assert isinstance(self.module, FlattenParamsWrapper)
                 state_dict = self.module.flat_state_dict(*args, **kwargs)
@@ -674,9 +678,8 @@ class FullyShardedDataParallel(nn.Module):
             for k in state_dict.keys():
                 state_dict[k] = state_dict[k].cpu()
-        if self.mixed_precision:
-            # In case we are in mixed precision, restore buffers back to buffer_dtype.
-            self._cast_buffers()
+        # In case we are in mixed precision, restore buffers back to buffer_dtype.
+        maybe_cast_buffers()
         return state_dict
 
     @typing.overload
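
The `maybe_cast_buffers` closure introduced above guards `self._cast_buffers(dtype=...)` behind the `mixed_precision` flag: each branch now casts buffers to fp32 right before `state_dict()` is collected, and the final call with no argument restores them to `buffer_dtype`. The body of `_cast_buffers` is not part of this diff, so the stand-in below is an assumption for illustration only; a buffer cast of this kind could look roughly like this:

```python
from typing import Optional

import torch
import torch.nn as nn


def cast_buffers_(module: nn.Module, buffer_dtype: torch.dtype, dtype: Optional[torch.dtype] = None) -> None:
    """Cast every floating-point buffer in place to `dtype`, falling back to `buffer_dtype`."""
    target = dtype or buffer_dtype
    for _, buf in module.named_buffers():
        if torch.is_floating_point(buf):
            buf.data = buf.data.to(dtype=target)
```

Calling such a helper with an explicit `torch.float32` before collecting the state dict, and with no dtype afterwards, mirrors the `maybe_cast_buffers(torch.float32)` / `maybe_cast_buffers()` pairing in the diff.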
@@ -860,7 +863,7 @@ class FullyShardedDataParallel(nn.Module):
     def _lazy_init(self) -> None:
         """Initialization steps that should happen lazily, typically right
         before the first forward pass.
         """
         # Initialize param attributes lazily, in case the param's dtype or
         # device changes after __init__.
@@ -101,8 +101,9 @@ class TestSaveLoadStateDict(DistributedTest):
     def _test_state_dict_before_forward(cls, config, rank, group):
         ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
         sd = ddp_model.state_dict()
-        wt = sd["embed_tokens.weight"]
-        assert wt.dtype == torch.float32, f"got dtype {wt.dtype} expected torch.float32"
+        for param_name in ("embed_tokens.weight", "vocab_bias"):
+            wt = sd[param_name]
+            assert wt.dtype == torch.float32, f"got dtype {wt.dtype} for {param_name}, expected torch.float32"
         cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision)
 
     @classmethod
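
The test now also checks `vocab_bias`, which, judging from the commit title, is presumably a registered buffer on the test model; the model definition is not part of this diff, so the toy module below is only a hypothetical illustration of the shape of the problem:

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self, vocab_size: int = 16, dim: int = 8) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        # A registered buffer travels with state_dict() but is not a parameter,
        # so mixed-precision casting has to handle it separately.
        self.register_buffer("vocab_bias", torch.zeros(vocab_size))

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(tokens).sum(dim=-1) + self.vocab_bias[tokens]
```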
@@ -232,16 +233,10 @@ class TestStateDictDeviceDtype(DistributedTest):
             assert v.device.type == "cuda", v.device.type
 
         expected_dtype = torch.float16 if pure_fp16 else torch.float32
-        buffers = {
-            k.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "") for k, _ in fsdp_model.named_buffers()
-        }
         for k, v in sd.items():
             if not torch.is_floating_point(v):
                 continue
-            if k in buffers:
-                assert v.dtype == fsdp_model.buffer_dtype, f"{v.dtype} != {fsdp_model.buffer_dtype}"
-            else:
-                assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
+            assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
 
     @skip_if_cuda
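
After the fix, buffers in the returned state dict are no longer a special case: every floating-point entry follows the same expected dtype as the parameters (fp32 normally, fp16 under `pure_fp16`). The helper below is a small standalone restatement of what the simplified test enforces; the function name is illustrative and not part of the test suite:

```python
from typing import Dict

import torch


def assert_uniform_dtype(sd: Dict[str, torch.Tensor], expected: torch.dtype) -> None:
    """Check that every floating-point entry in a state dict shares one dtype."""
    for name, tensor in sd.items():
        if not torch.is_floating_point(tensor):
            continue  # e.g. integer counters such as num_batches_tracked
        assert tensor.dtype == expected, f"{name}: {tensor.dtype} != {expected}"
```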