Unverified Commit 5ed53a3e authored by Min Xu, committed by GitHub

[fix] handle EMA in the state_dict (#1044)

* [fix] handle EMA in the state_dict

* better fix
parent 4cb293e8
@@ -275,7 +275,7 @@ class FullyShardedDataParallel(nn.Module):
             Default: False
         state_dict_device (torch.device, Optional):
             device for parameters returned by :func:`state_dict`. If not given,
-            this will default to ``compute_dtype``. Note that only the device
+            this will default to ``compute_device``. Note that only the device
             type will be respected (e.g., "cuda:0" and "cuda:1" are the same).
         clear_autocast_cache (bool):
             When using mixed precision training with `torch.amp.autocast`, if the model weights
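A minimal sketch of the documented behavior, assuming a CUDA device and an already-initialized single-process ``torch.distributed`` group (the ``Linear`` module is only a placeholder):

    # Minimal sketch, assuming dist.init_process_group() was already called with
    # world_size=1 and that a CUDA device is available.
    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    model = FSDP(
        torch.nn.Linear(8, 8).cuda(),
        state_dict_device=torch.device("cpu"),  # omit to default to compute_device
    )
    state = model.state_dict()
    assert all(t.device.type == "cpu" for t in state.values())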
@@ -2532,23 +2532,49 @@ def _post_state_dict_hook(
     if state_dict_on_rank_0_only and dist.get_rank() != 0:
         state_dict.clear()
         return state_dict
-    # Assuming we are in a ``summon_full_params()`` context, we need to clone
-    # each tensor so that it does not get freed (in-place) when the context
-    # exits. At the same time, this hook can be called multiple times
-    # recursively, so we need to make sure that we only clone each tensor at
-    # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
-    # which keeps track of tensors that are no longer at risk of being freed.
+
+    def apply_to_tensor(obj: torch.Tensor) -> torch.Tensor:
+        """Apply needed operations on a tensor."""
+        assert isinstance(obj, torch.Tensor), f"Expect a tensor, got {type(obj)}"
+        # Already applied?
+        if getattr(obj, "_has_been_cloned", False):
+            return obj
+        if obj.device.type != module.state_dict_device.type:
+            # Move to right device. This is often used to save GPU memory.
+            obj = obj.to(device=module.state_dict_device)
+        elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
+            # If we are in a ``summon_full_params()`` context, we need to clone
+            # each tensor so that it does not get freed (in-place) when the context
+            # exits. At the same time, this hook can be called multiple times
+            # recursively, so we need to make sure that we only clone each tensor at
+            # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
+            # which keeps track of tensors that are no longer at risk of being freed.
+            #
+            # "elif" because .to() clones the object too.
+            obj = obj.clone()
+        # Both .to() and .clone() copy into a new object. So we set this flag.
+        obj._has_been_cloned = True
+        return obj
+
+    # State_dict is supposed to be a flat dict (not nested). The
+    # keys are encoded with hierarchy. Therefore, we can loop
+    # over the dict here. (See else case below for additional notes.)
     for key in state_dict.keys():
-        if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
+        # Skip keys without right prefix.
+        if not key.startswith(prefix):
             continue
-        if state_dict[key].device.type != module.state_dict_device.type:
-            state_dict[key] = state_dict[key].to(device=module.state_dict_device)
-            state_dict[key]._has_been_cloned = True
-        elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
-            # We copy the state_dict since full param will be freed after we
-            # exit the ``summon_full_params()`` context.
-            state_dict[key] = state_dict[key].clone()
-            state_dict[key]._has_been_cloned = True
+        elif isinstance(state_dict[key], torch.Tensor):
+            state_dict[key] = apply_to_tensor(state_dict[key])
+        else:
+            # For example, EMA module from data2vec is a dict of tensors.
+            logging.warning(
+                "Got an unexpected data type in state_dict: " f"key={key} value_type={type(state_dict[key])}"
+            )
 
     # Remove "_fsdp_wrapped_module." prefix
     replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
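The ``else`` branch above exists because a wrapped module's ``state_dict()`` may contain non-tensor values; data2vec's EMA module, for instance, contributes a dict of tensors. A minimal standalone sketch of that situation (the ``ModuleWithEMA`` class and its ``_ema`` key are hypothetical, for illustration only): the old hook called ``.device`` and ``.to()`` on every value and would fail on such an entry, whereas the new hook processes tensors and only warns on anything else.

    import torch
    import torch.nn as nn

    class ModuleWithEMA(nn.Module):
        """Toy module whose state_dict carries a non-tensor entry (a dict of tensors)."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def state_dict(self, *args, **kwargs):
            sd = super().state_dict(*args, **kwargs)
            # Hypothetical EMA entry, mimicking what a data2vec-style tracker stores.
            sd["_ema"] = {"linear.weight": self.linear.weight.detach().clone()}
            return sd

    sd = ModuleWithEMA().state_dict()
    for key, value in sd.items():
        if isinstance(value, torch.Tensor):
            sd[key] = value.to("cpu")  # tensors: safe to move/clone, as the hook does
        else:
            print(f"skipping non-tensor state_dict entry: key={key} type={type(value)}")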