Commit e7652751, authored by David Yan and committed by Facebook GitHub Bot
Browse files

Save and load model EMA state for sharded state dicts in FSDPCheckpointer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/519

Prior to this, the FSDP checkpointer did not save an EMA state that matched the model state when the model used a sharded state dict. This diff adds that functionality.

Reviewed By: YanjunChen329

Differential Revision: D44270790

fbshipit-source-id: f522765ad56e8279f355c43a19f26c3b6bcf01e3
parent 67267821
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import copy
import json
import os
from typing import Callable, cast, IO
......@@ -106,7 +107,6 @@ class FSDPCheckpointer(QATCheckpointer):
)
ema_state = checkpoint.pop("ema_state")
scatter_ema_state_dict(ema_state, self.model)
# Convert local ckpt by resaving the current state
if convert_local_ckpt_to_global:
self.logger.info(
......@@ -243,15 +243,22 @@ def gather_ema_state_dict(ema_state, model: FSDPWrapper):
):
state = EMAState.FromModel(model)
return state.state
return ema_state.state_dict()
elif model.state_dict_type == StateDictType.SHARDED_STATE_DICT:
with ema_state.apply_and_restore(model):
# must deepcopy the state dict, else we return a reference to the model state
return dict(copy.deepcopy(model.state_dict()))
else:
return ema_state.state_dict()
def scatter_ema_state_dict(ema_state_dict, model: FSDPWrapper):
"""
Load a full/local EMA state dict to a FSDP model.
Load a full/sharded/local EMA state dict to a FSDP model.
If loading full state dict, ema_state_dict needs to be properly sharded for each FSDP process to store locally
Note that, at load-time, model.state_dict_type is automatically set to the type of the state dict being loaded
by accessing metadata, so there's no possibility of a save-load mismatch
"""
if model.state_dict_type == StateDictType.FULL_STATE_DICT:
if model.load_state_dict_type == StateDictType.FULL_STATE_DICT:
# Store the current model state.
old_local_state = EMAState.FromModel(model)
......@@ -271,5 +278,17 @@ def scatter_ema_state_dict(ema_state_dict, model: FSDPWrapper):
model.ema_state.save_from(model)
# Restore the old model state
old_local_state.apply_to(model)
elif model.load_state_dict_type == StateDictType.SHARDED_STATE_DICT:
# Store current model state temporarily
old_state = EMAState.FromModel(model)
# Load the ema state dict into the model
model.load_state_dict(ema_state_dict)
# Save ema state with correct FQNs via EMAState.save_from
model.ema_state.save_from(model)
# restore old model state
old_state.apply_to(model)
else:
model.ema_state.load_state_dict(ema_state_dict)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment