Unverified commit 25cebf85 authored by Pete, committed by GitHub

Fix buffer dtype in `FSDP.state_dict()` when using mixed precision (#705)

* add failing test for buffer dtype

* fix buffer dtype issue

* update CHANGELOG

* fix
parent 3443a635
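For context, here is a plain-PyTorch sketch of the symptom this commit addresses (standard PyTorch only, not fairscale code): floating-point buffers such as BatchNorm running statistics follow module-wide dtype casts, so a wrapper that keeps buffers in half precision at runtime will also expose half-precision buffers through `state_dict()` unless it temporarily casts them back around the call.

```python
# Illustration only: standard PyTorch, not fairscale's FSDP implementation.
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
bn.half()  # a mixed-precision wrapper may keep buffers in fp16 at runtime

# The bug: state_dict() surfaces the runtime (fp16) buffers.
assert bn.state_dict()["running_mean"].dtype == torch.float16

# The fix: cast buffers to full precision for the duration of state_dict(),
# then restore the runtime dtype afterwards.
bn.float()
assert bn.state_dict()["running_mean"].dtype == torch.float32
bn.half()  # restore runtime buffer dtype
```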
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
 - checkpointing: use dummy tensor to ensure backward pass is called [#701]
+- FSDP: fixed bug where buffers returned in `state_dict()` could still be half precision when `mixed_precision` is set to `True`.
 ### Added
@@ -653,17 +653,21 @@ class FullyShardedDataParallel(nn.Module):
         if torch.cuda.is_available():
             torch.cuda.synchronize()
         self._lazy_init()
-        if self.mixed_precision:
-            # Buffers dtype stays consistent with parameters.
-            self._cast_buffers(dtype=torch.float32)
+
+        def maybe_cast_buffers(dtype: Optional[torch.dtype] = None) -> None:
+            if self.mixed_precision:
+                self._cast_buffers(dtype=dtype)
 
         if self._return_full_state_dict:
             if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
                 with self.summon_full_params(recurse=False, volatile=True):
+                    maybe_cast_buffers(torch.float32)
                     state_dict = super().state_dict(*args, **kwargs)
             else:
+                maybe_cast_buffers(torch.float32)
                 state_dict = super().state_dict(*args, **kwargs)
         else:
+            maybe_cast_buffers(torch.float32)
             if self.flatten_parameters:
                 assert isinstance(self.module, FlattenParamsWrapper)
                 state_dict = self.module.flat_state_dict(*args, **kwargs)
@@ -674,9 +678,8 @@ class FullyShardedDataParallel(nn.Module):
             for k in state_dict.keys():
                 state_dict[k] = state_dict[k].cpu()
-        if self.mixed_precision:
-            # In case we are in mixed precision, restore buffers back to buffer_dtype.
-            self._cast_buffers()
+        # In case we are in mixed precision, restore buffers back to buffer_dtype.
+        maybe_cast_buffers()
         return state_dict
 
     @typing.overload
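The change above centralizes the cast into a local `maybe_cast_buffers` helper that is invoked in every branch that builds the state dict (including inside the `summon_full_params` context), and once more at the end to restore the runtime buffer dtype. Below is a minimal, self-contained sketch of that cast-and-restore pattern; `MixedPrecisionWrapper` and `runtime_buffer_dtype` are made-up names for illustration, not fairscale APIs.

```python
# Toy sketch of the cast-and-restore pattern; not fairscale code.
from typing import Optional

import torch
import torch.nn as nn


class MixedPrecisionWrapper(nn.Module):
    """Keeps the wrapped module's floating-point buffers in fp16 at runtime,
    but always returns fp32 buffers from state_dict()."""

    def __init__(self, module: nn.Module, runtime_buffer_dtype: torch.dtype = torch.float16):
        super().__init__()
        self.module = module
        self.runtime_buffer_dtype = runtime_buffer_dtype
        self._cast_buffers()  # buffers live in fp16 between state_dict() calls

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def _cast_buffers(self, dtype: Optional[torch.dtype] = None) -> None:
        dtype = dtype or self.runtime_buffer_dtype
        for submodule in self.module.modules():
            for name, buf in submodule.named_buffers(recurse=False):
                if torch.is_floating_point(buf):
                    setattr(submodule, name, buf.to(dtype))  # replaces the registered buffer

    def state_dict(self, *args, **kwargs):
        self._cast_buffers(torch.float32)  # full precision for checkpointing
        try:
            return super().state_dict(*args, **kwargs)
        finally:
            self._cast_buffers()  # restore the runtime dtype


# Usage: the checkpoint sees fp32 buffers, the live module keeps fp16 buffers.
wrapped = MixedPrecisionWrapper(nn.BatchNorm1d(4))
assert wrapped.state_dict()["module.running_mean"].dtype == torch.float32
assert dict(wrapped.named_buffers())["module.running_mean"].dtype == torch.float16
```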
@@ -101,8 +101,9 @@ class TestSaveLoadStateDict(DistributedTest):
     def _test_state_dict_before_forward(cls, config, rank, group):
         ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
         sd = ddp_model.state_dict()
-        wt = sd["embed_tokens.weight"]
-        assert wt.dtype == torch.float32, f"got dtype {wt.dtype} expected torch.float32"
+        for param_name in ("embed_tokens.weight", "vocab_bias"):
+            wt = sd[param_name]
+            assert wt.dtype == torch.float32, f"got dtype {wt.dtype} for {param_name}, expected torch.float32"
         cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision)
 
     @classmethod
@@ -232,15 +233,9 @@ class TestStateDictDeviceDtype(DistributedTest):
                 assert v.device.type == "cuda", v.device.type
 
         expected_dtype = torch.float16 if pure_fp16 else torch.float32
-        buffers = {
-            k.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "") for k, _ in fsdp_model.named_buffers()
-        }
         for k, v in sd.items():
             if not torch.is_floating_point(v):
                 continue
-            if k in buffers:
-                assert v.dtype == fsdp_model.buffer_dtype, f"{v.dtype} != {fsdp_model.buffer_dtype}"
-            else:
-                assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
+            assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
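In the same spirit as the updated tests, here is a generic standalone check, not the fairscale test harness (`assert_state_dict_is_fp32` is a made-up helper), that every floating-point entry of a state dict, buffers included, comes back in full precision:

```python
import torch
import torch.nn as nn


def assert_state_dict_is_fp32(model: nn.Module) -> None:
    """Fail if any floating-point parameter or buffer in the state dict is not fp32."""
    for key, value in model.state_dict().items():
        if not torch.is_floating_point(value):
            continue  # e.g. BatchNorm's int64 num_batches_tracked
        assert value.dtype == torch.float32, f"got dtype {value.dtype} for {key}, expected torch.float32"


if __name__ == "__main__":
    assert_state_dict_is_fp32(nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4)))
    print("ok")
```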