Unverified commit 5ed53a3e authored by Min Xu, committed by GitHub

[fix] handle EMA in the state_dict (#1044)

* [fix] handle EMA in the state_dict

* better fix
parent 4cb293e8
@@ -275,7 +275,7 @@ class FullyShardedDataParallel(nn.Module):
             Default: False
         state_dict_device (torch.device, Optional):
             device for parameters returned by :func:`state_dict`. If not given,
-            this will default to ``compute_dtype``. Note that only the device
+            this will default to ``compute_device``. Note that only the device
             type will be respected (e.g., "cuda:0" and "cuda:1" are the same).
         clear_autocast_cache (bool):
             When using mixed precision training with `torch.amp.autocast`, if the model weights
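
The first hunk only corrects the docstring: the fallback for ``state_dict_device`` is ``compute_device``, not ``compute_dtype``. Below is a minimal usage sketch of that parameter, assuming a single-process gloo group and a fairscale version that allows CPU-only FSDP for testing; the wrapped module, sizes, and port are placeholders, not taken from this commit.

import os

import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Single-process "gloo" group so FSDP can be constructed without GPUs
# (assumption: your fairscale version supports CPU-only FSDP for testing).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# Ask state_dict() to return CPU tensors; per the docstring, only the
# device *type* is respected, and the default is ``compute_device``.
model = FSDP(nn.Linear(16, 16), state_dict_device=torch.device("cpu"))
for key, value in model.state_dict().items():
    print(key, value.device)

dist.destroy_process_group()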
@@ -2532,23 +2532,49 @@ def _post_state_dict_hook(
     if state_dict_on_rank_0_only and dist.get_rank() != 0:
         state_dict.clear()
         return state_dict
-    # Assuming we are in a ``summon_full_params()`` context, we need to clone
-    # each tensor so that it does not get freed (in-place) when the context
-    # exits. At the same time, this hook can be called multiple times
-    # recursively, so we need to make sure that we only clone each tensor at
-    # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
-    # which keeps track of tensors that are no longer at risk of being freed.
+
+    def apply_to_tensor(obj: torch.Tensor) -> torch.Tensor:
+        """Apply the needed operations on a tensor."""
+        assert isinstance(obj, torch.Tensor), f"Expect a tensor, got {type(obj)}"
+
+        # Already applied?
+        if getattr(obj, "_has_been_cloned", False):
+            return obj
+
+        if obj.device.type != module.state_dict_device.type:
+            # Move to the right device. This is often used to save GPU memory.
+            obj = obj.to(device=module.state_dict_device)
+        elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
+            # If we are in a ``summon_full_params()`` context, we need to clone
+            # each tensor so that it does not get freed (in-place) when the context
+            # exits. At the same time, this hook can be called multiple times
+            # recursively, so we need to make sure that we only clone each tensor at
+            # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
+            # which keeps track of tensors that are no longer at risk of being freed.
+            #
+            # "elif" because .to() clones the object too.
+            obj = obj.clone()
+
+        # Both .to() and .clone() copy to a new object, so we set this flag.
+        obj._has_been_cloned = True
+        return obj
+
+    # The state_dict is supposed to be a flat dict (not nested); the keys
+    # encode the module hierarchy. Therefore, we can simply loop over the
+    # dict here. (See the else case below for additional notes.)
     for key in state_dict.keys():
-        if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
+        # Skip keys without the right prefix.
+        if not key.startswith(prefix):
             continue
-        if state_dict[key].device.type != module.state_dict_device.type:
-            state_dict[key] = state_dict[key].to(device=module.state_dict_device)
-            state_dict[key]._has_been_cloned = True
-        elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
-            # We copy the state_dict since full param will be freed after we
-            # exit the ``summon_full_params()`` context.
-            state_dict[key] = state_dict[key].clone()
-            state_dict[key]._has_been_cloned = True
+        elif isinstance(state_dict[key], torch.Tensor):
+            state_dict[key] = apply_to_tensor(state_dict[key])
+        else:
+            # For example, the EMA module from data2vec is a dict of tensors.
+            logging.warning(
+                f"Got an unexpected data type in state_dict: key={key} value_type={type(state_dict[key])}"
+            )
 
     # Remove "_fsdp_wrapped_module." prefix
     replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
...
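
The second hunk factors the per-tensor move/clone logic into ``apply_to_tensor`` and, crucially, no longer assumes every state_dict value is a tensor: non-tensor entries, such as the dict of tensors contributed by data2vec's EMA module, are now left as-is with a warning instead of breaking the hook. The standalone sketch below mimics that loop on a plain dict; the ``module``/``prefix`` handling and the ``TrainingState`` check are replaced by simple stand-ins, so it illustrates the control flow rather than reproducing fairscale's actual hook.

import logging

import torch

STATE_DICT_DEVICE = torch.device("cpu")  # stand-in for module.state_dict_device

def apply_to_tensor(obj: torch.Tensor, in_summon_full_params: bool) -> torch.Tensor:
    """Per-tensor logic mirroring the hook: move or clone, at most once."""
    assert isinstance(obj, torch.Tensor), f"Expect a tensor, got {type(obj)}"
    if getattr(obj, "_has_been_cloned", False):
        return obj  # already processed by an earlier (recursive) hook call
    if obj.device.type != STATE_DICT_DEVICE.type:
        obj = obj.to(device=STATE_DICT_DEVICE)  # .to() already makes a copy
    elif in_summon_full_params:
        obj = obj.clone()  # keep the value alive after full params are freed
    obj._has_been_cloned = True
    return obj

state_dict = {
    "weight": torch.randn(2, 2),
    # A non-tensor entry, e.g. what data2vec's EMA module contributes.
    "ema_params": {"w": torch.randn(2, 2)},
}

for key in state_dict.keys():
    if isinstance(state_dict[key], torch.Tensor):
        state_dict[key] = apply_to_tensor(state_dict[key], in_summon_full_params=True)
    else:
        # Left untouched; the hook only warns instead of raising.
        logging.warning(
            "Got an unexpected data type in state_dict: key=%s value_type=%s",
            key,
            type(state_dict[key]),
        )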