Commit 206a05c6 authored by Ed Pizzi's avatar Ed Pizzi Committed by Facebook GitHub Bot
Browse files

Propagate include_frozen/buffers to EMAState in FSDP FULL_STATE_DICT checkpoints

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/620

EMA can be configured to exclude frozen (`requires_grad=False`) parameters and buffers, reducing memory use and checkpoint size.

However `FULL_STATE_DICT` FSDP + EMA checkpoints construct an inner `EMAState` after unsharding FSDP parameters. This inner `EMAState` uses default `include_frozen` and `include_buffers` settings, resulting in checkpoints containing frozen parameters and buffers regardless of settings.

Propagate `include_frozen` and `include_buffers` settings to the inner `EMAState` when gathering `FULL_STATE_DICT` FSDP EMA state.

This change only affects frozen parameters with a parallel fix to PyTorch FSDP to propagate `requires_grad` across parameter sharding/unsharding: https://github.com/pytorch/pytorch/pull/109892.

Reviewed By: daveboat

Differential Revision: D49517178

fbshipit-source-id: 0fe159dcec9ec1f2c456ae2ee7798681e7536249
parent 93037c4e
......@@ -38,7 +38,7 @@ def scatter_optimizer_state_dict(optimizer, optim_state_dict, model: FSDPWrapper
optimizer.load_state_dict(optim_state_dict)
def gather_ema_state_dict(ema_state, model: FSDPWrapper):
def gather_ema_state_dict(ema_state: EMAState, model: FSDPWrapper):
"""
Get full/local EMA state dict from an FSDP model.
If using full state dict, gather local sharded EMA states from all FSDP processes and aggregate them into a full EMA state dict
......@@ -52,7 +52,11 @@ def gather_ema_state_dict(ema_state, model: FSDPWrapper):
offload_to_cpu=model.offload_to_cpu,
rank0_only=model.rank0_only,
):
state = EMAState.FromModel(model)
state = EMAState.FromModel(
model,
include_frozen=ema_state.include_frozen,
include_buffer=ema_state.include_buffer,
)
return state.state
elif model.state_dict_type == StateDictType.SHARDED_STATE_DICT:
with ema_state.apply_and_restore(model):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment