Commit e7652751 authored by David Yan's avatar David Yan Committed by Facebook GitHub Bot
Browse files

Save and load model EMA state for sharded state dicts in FSDPCheckpointer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/519

Prior to this, FSDP checkpointer did not save EMA state which matched the model state when the model used sharded state dict. This diff adds this functionality.

Reviewed By: YanjunChen329

Differential Revision: D44270790

fbshipit-source-id: f522765ad56e8279f355c43a19f26c3b6bcf01e3
parent 67267821
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import copy
import json import json
import os import os
from typing import Callable, cast, IO from typing import Callable, cast, IO
...@@ -106,7 +107,6 @@ class FSDPCheckpointer(QATCheckpointer): ...@@ -106,7 +107,6 @@ class FSDPCheckpointer(QATCheckpointer):
) )
ema_state = checkpoint.pop("ema_state") ema_state = checkpoint.pop("ema_state")
scatter_ema_state_dict(ema_state, self.model) scatter_ema_state_dict(ema_state, self.model)
# Convert local ckpt by resaving the current state # Convert local ckpt by resaving the current state
if convert_local_ckpt_to_global: if convert_local_ckpt_to_global:
self.logger.info( self.logger.info(
...@@ -243,15 +243,22 @@ def gather_ema_state_dict(ema_state, model: FSDPWrapper): ...@@ -243,15 +243,22 @@ def gather_ema_state_dict(ema_state, model: FSDPWrapper):
): ):
state = EMAState.FromModel(model) state = EMAState.FromModel(model)
return state.state return state.state
return ema_state.state_dict() elif model.state_dict_type == StateDictType.SHARDED_STATE_DICT:
with ema_state.apply_and_restore(model):
# must deepcopy the state dict, else we return a reference to the model state
return dict(copy.deepcopy(model.state_dict()))
else:
return ema_state.state_dict()
def scatter_ema_state_dict(ema_state_dict, model: FSDPWrapper): def scatter_ema_state_dict(ema_state_dict, model: FSDPWrapper):
""" """
Load a full/local EMA state dict to a FSDP model. Load a full/sharded/local EMA state dict to a FSDP model.
If loading full state dict, ema_state_dict needs to be properly sharded for each FSDP process to store locally If loading full state dict, ema_state_dict needs to be properly sharded for each FSDP process to store locally
Note that, at load-time, model.state_dict_type is automatically set to the type of the state dict being loaded
by accessing metadata, so there's no possibility of a save-load mismatch
""" """
if model.state_dict_type == StateDictType.FULL_STATE_DICT: if model.load_state_dict_type == StateDictType.FULL_STATE_DICT:
# Store the current model state. # Store the current model state.
old_local_state = EMAState.FromModel(model) old_local_state = EMAState.FromModel(model)
...@@ -271,5 +278,17 @@ def scatter_ema_state_dict(ema_state_dict, model: FSDPWrapper): ...@@ -271,5 +278,17 @@ def scatter_ema_state_dict(ema_state_dict, model: FSDPWrapper):
model.ema_state.save_from(model) model.ema_state.save_from(model)
# Restore the old model state # Restore the old model state
old_local_state.apply_to(model) old_local_state.apply_to(model)
elif model.load_state_dict_type == StateDictType.SHARDED_STATE_DICT:
# Store current model state temporarily
old_state = EMAState.FromModel(model)
# Load the ema state dict into the model
model.load_state_dict(ema_state_dict)
# Save ema state with correct FQNs via EMAState.save_from
model.ema_state.save_from(model)
# restore old model state
old_state.apply_to(model)
else: else:
model.ema_state.load_state_dict(ema_state_dict) model.ema_state.load_state_dict(ema_state_dict)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment