"...text-generation-inference.git" did not exist on "5d27f5259b710c5d77925930bd804edcaa3821a6"
Commit 477629d0 authored by Anthony Chen, committed by Facebook GitHub Bot

Make EMA checkpointing with FSDP more robust

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/615

The previous FSDP EMA checkpointing logic directly handles `EMAState`: it manually calls `FSDP.summon_full_params()` to gather the full model parameters, then reconstructs/loads an `EMAState` for checkpointing. This logic has two drawbacks:

1. `FSDP.summon_full_params()` gathers all model weights at the same time, which can cause OOM if the model does not fit on a single GPU; this is quite common for FSDP workloads (see the sketch after this list).
2. Directly saving and loading `EMAState` is error-prone: the EMA state dict has different semantics and behavior from `model.state_dict()`, yet users often expect it to work seamlessly like the model state dict.
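
A minimal sketch of the contrast, using plain PyTorch FSDP APIs; the `snapshot_weights` helper and the already-wrapped `fsdp_model` are illustrative assumptions, not d2go code:

```python
import copy

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


def snapshot_weights(fsdp_model, use_sharded: bool = True):
    """Illustrative only: snapshot an FSDP-wrapped model's weights two ways."""
    if not use_sharded:
        # Old-style gather: every rank materializes the full, unsharded parameters
        # at once, which can OOM when the model does not fit on a single GPU.
        with FSDP.summon_full_params(fsdp_model, writeback=False):
            return {k: v.detach().clone() for k, v in fsdp_model.named_parameters()}
    # Sharded path: each rank serializes only the shards it already owns, so peak
    # memory stays close to what training already uses.
    with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
        return copy.deepcopy(fsdp_model.state_dict())
```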

This diff changes the EMA save/load logic to use `model.state_dict()` directly, addressing both pain points above.
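
This works because, under FSDP, the configured state-dict type determines what `model.state_dict()` returns, so a single code path covers both the full and sharded formats. A rough illustration with plain PyTorch FSDP (d2go's `FSDPWrapper` is assumed to manage the equivalent configuration itself; `fsdp_model` is an already-wrapped model):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullStateDictConfig,
    StateDictType,
)

# Full state dict: unsharded tensors, optionally gathered only on rank 0 and on CPU.
with FSDP.state_dict_type(
    fsdp_model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    full_sd = fsdp_model.state_dict()

# Sharded state dict: each rank returns only its own shards.
with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
    sharded_sd = fsdp_model.state_dict()
```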

Reviewed By: wat3rBro

Differential Revision: D48813697

fbshipit-source-id: be53c2677d2e493ba923508bbd82d9d295397941
parent c668ed4e
 import copy
+import logging
+
 from d2go.modeling.ema import EMAState
 from d2go.trainer.fsdp import FSDPWrapper
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -7,6 +9,8 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
     StateDictType,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def gather_optimizer_state_dict(optimizer, model: FSDPWrapper):
     """
@@ -43,22 +47,10 @@ def gather_ema_state_dict(ema_state: EMAState, model: FSDPWrapper):
     Get full/local EMA state dict from an FSDP model.
     If using full state dict, gather local sharded EMA states from all FSDP processes and aggregate them into a full EMA state dict
     """
-    if model.state_dict_type == StateDictType.FULL_STATE_DICT:
-        # Apply local ema states to the model and unshard them
-        with ema_state.apply_and_restore(model):
-            with FSDP.summon_full_params(
-                model,
-                writeback=False,
-                offload_to_cpu=model.offload_to_cpu,
-                rank0_only=model.rank0_only,
-            ):
-                state = EMAState.FromModel(
-                    model,
-                    include_frozen=ema_state.include_frozen,
-                    include_buffer=ema_state.include_buffer,
-                )
-            return state.state
-    elif model.state_dict_type == StateDictType.SHARDED_STATE_DICT:
+    if model.state_dict_type in [
+        StateDictType.FULL_STATE_DICT,
+        StateDictType.SHARDED_STATE_DICT,
+    ]:
         with ema_state.apply_and_restore(model):
             # must deepcopy the state dict, else we return a reference to the model state
             return dict(copy.deepcopy(model.state_dict()))
@@ -73,32 +65,19 @@ def scatter_ema_state_dict(ema_state_dict, model: FSDPWrapper):
     Note that, at load-time, model.state_dict_type is automatically set to the type of the state dict being loaded
     by accessing metadata, so there's no possibility of a save-load mismatch
     """
-    if model.load_state_dict_type == StateDictType.FULL_STATE_DICT:
-        # Store the current model state.
-        old_local_state = EMAState.FromModel(model)
-        # Apply ema_state as a FULL state dict to the model so it can be properly sharded
-        # Currently only [offload_to_cpu=False, rank0_only=False] is supported
-        with FSDP.summon_full_params(
-            model,
-            writeback=True,
-            offload_to_cpu=False,
-            rank0_only=False,
-        ):
-            ema_state = EMAState()
-            ema_state.load_state_dict(ema_state_dict)
-            ema_state.apply_to(model)
-        # Load ema_state from model
-        model.ema_state.save_from(model)
-        # Restore the old model state
-        old_local_state.apply_to(model)
-    elif model.load_state_dict_type == StateDictType.SHARDED_STATE_DICT:
+    if model.load_state_dict_type in [
+        StateDictType.FULL_STATE_DICT,
+        StateDictType.SHARDED_STATE_DICT,
+    ]:
         # Store current model state temporarily
         old_state = EMAState.FromModel(model)
         # Load the ema state dict into the model
-        model.load_state_dict(ema_state_dict)
+        m, u = model.load_state_dict(ema_state_dict, strict=False)
+        if len(m) > 0:
+            logger.info("Missing keys while loading EMA: %s", m)
+        if len(u) > 0:
+            logger.info("Unexpected keys while loading EMA: %s", u)
         # Save ema state with correct FQNs via EMAState.save_from
         model.ema_state.save_from(model)
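
For context, a hypothetical sketch of how the two helpers might be wired into checkpoint save/load code; the `checkpoint` dict and the FSDP-wrapped `model` with an attached `ema_state` are assumptions for illustration, not part of this diff:

```python
# Save path: EMA weights are captured through model.state_dict() semantics,
# so they live under the same FQNs as a regular model checkpoint.
checkpoint = {
    "model": model.state_dict(),
    "ema_state": gather_ema_state_dict(model.ema_state, model),
}

# Load path: push the saved EMA dict back into model.ema_state.
scatter_ema_state_dict(checkpoint["ema_state"], model)
```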