Commit 46606a02 authored by David Yan, committed by Facebook GitHub Bot

Add tests for sharded_state_dict and fix compatibility problems

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/511

Add tests for sharded_state_dict integration in AIF Checkpointer

Fix compatibility problems including:
1. small API errors in flatten_sharded_optim_state_dict
2. deprecate model.use_local_state_dict and model.load_local_state_dict
3. fix auto conversion for local_state_dict
4. fix T148056077: add metadata to differentiate between local_state_dict and sharded_state_dict when loading a directory with FSDPCheckpointer

Reviewed By: YanjunChen329

Differential Revision: D44160045

fbshipit-source-id: f607b7076d0e49b9407f9adfbc8ecfe439c3b0c9
parent fbc1c2e8
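A rough usage sketch of the metadata-aware flow described in item 4 of the summary. This is a hedged illustration only; the constructor arguments, paths, and exact file layout are assumptions for the example, not taken from this diff.

# Hedged usage sketch: saving with a local/sharded state dict also writes
# metadata.json into the checkpoint directory, and load() reads it back to
# pick the matching StateDictType instead of assuming LOCAL_STATE_DICT.
checkpointer = FSDPCheckpointer(model, save_dir="/tmp/ckpt")  # model: FSDPWrapper
checkpointer.save("model_0001")
# -> /tmp/ckpt/model_0001/rank{i}.pth plus metadata.json next to the shards (assumed layout)
checkpointer.load("/tmp/ckpt/model_0001")
# -> load() branches on metadata["state_dict_type"]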
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import json
import os
from typing import Callable, cast, IO
@@ -49,25 +50,32 @@ class FSDPCheckpointer(QATCheckpointer):
load_path = path
if path:
# loading path is a directory: local or sharded state dict is used
# TODO(T148056077): Make loading sharded state dicts more elegant
if self.path_manager.isdir(path):
+ # Get state dict type from metadata file
+ metadata = self._load_metadata(path)
+ state_dict_type = (
+     metadata["state_dict_type"] if metadata else "LOCAL_STATE_DICT"
+ )
+ assert state_dict_type in ["LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
+ type_str = "local" if state_dict_type == "LOCAL_STATE_DICT" else "sharded"
self.logger.info(
"[FSDPCheckpointer] Loading from local or sharded checkpoint ..."
f"[FSDPCheckpointer] Loading from {type_str} checkpoint ..."
)
- self.model.load_local_state_dict = True
+ self.model.load_state_dict_type = StateDictType[state_dict_type]
load_path = os.path.join(path, f"rank{comm.get_rank()}.pth")
# loading path is a file: full global state dict is used
else:
self.logger.info(
"[FSDPCheckpointer] Loading from global checkpoint ..."
"[FSDPCheckpointer] Loading from full checkpoint ..."
)
- self.model.load_local_state_dict = False
+ self.model.load_state_dict_type = StateDictType.FULL_STATE_DICT
# Convert local ckpt to global ckpt when we load from a local ckpt but want to save to global ckpt
convert_local_ckpt_to_global = (
path
- and self.model.load_local_state_dict
- and not self.model.use_local_state_dict
+ and self.model.load_state_dict_type == StateDictType.LOCAL_STATE_DICT
+ and self.model.state_dict_type == StateDictType.FULL_STATE_DICT
)
# Load all checkpointables from local ckpt if we want to convert to global ckpt
@@ -156,8 +164,10 @@ class FSDPCheckpointer(QATCheckpointer):
self._save_file(data, save_file)
# Main process tags last checkpoint if no errors in all processes
comm.synchronize()
- if comm.is_main_process() and tag_last_ckpt:
-     self.tag_last_checkpoint(name)
+ if comm.is_main_process():
+     self._save_metadata(new_save_dir)
+     if tag_last_ckpt:
+         self.tag_last_checkpoint(name)
elif comm.is_main_process():
basename = "{}.pth".format(name)
save_file = os.path.join(self.save_dir, basename)
@@ -176,6 +186,20 @@ class FSDPCheckpointer(QATCheckpointer):
with interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher()):
return super()._load_file(f)
+ def _save_metadata(self, path):
+     metadata_file = os.path.join(path, "metadata.json")
+     obj = {"state_dict_type": self.model.state_dict_type.name}
+     with self.path_manager.open(metadata_file, "w") as f:
+         json.dump(obj, f)
+
+ def _load_metadata(self, path):
+     metadata_file = os.path.join(path, "metadata.json")
+     if self.path_manager.exists(metadata_file):
+         with self.path_manager.open(metadata_file, "r") as f:
+             return json.load(f)
+     else:
+         return None
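For context, a minimal sketch of the metadata round trip these helpers implement: _save_metadata writes the enum name of the wrapper's state_dict_type, load() maps it back through StateDictType[...], and a missing metadata.json returns None, which load() treats as LOCAL_STATE_DICT for older checkpoints. Values below are illustrative.

# Minimal round-trip sketch (values illustrative).
import json
from torch.distributed.fsdp import StateDictType

payload = json.dumps({"state_dict_type": StateDictType.SHARDED_STATE_DICT.name})
restored = StateDictType[json.loads(payload)["state_dict_type"]]
assert restored is StateDictType.SHARDED_STATE_DICT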
def gather_optimizer_state_dict(optimizer, model: FSDPWrapper):
"""
@@ -198,7 +222,7 @@ def scatter_optimizer_state_dict(optimizer, optim_state_dict, model: FSDPWrapper
optim_state_dict = FSDP.shard_full_optim_state_dict(optim_state_dict, model)
elif model.state_dict_type == StateDictType.SHARDED_STATE_DICT:
optim_state_dict = FSDP.flatten_sharded_optim_state_dict(
-     optim_state_dict, model
+     optim_state_dict, model, optimizer
)
optimizer.load_state_dict(optim_state_dict)
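A minimal sketch of how the two helpers above might be paired around saving and loading optimizer state. The torch.save/torch.load calls and the helper names save_optim/load_optim are illustrative placeholders; FSDPCheckpointer persists state through its own file handling.

# Hedged sketch: pairing gather/scatter around a save/load cycle.
import torch

def save_optim(optimizer, model, ckpt_path):
    # Gather a state dict that matches model.state_dict_type, then persist it.
    torch.save({"optimizer": gather_optimizer_state_dict(optimizer, model)}, ckpt_path)

def load_optim(optimizer, model, ckpt_path):
    # Re-shard / flatten the loaded state to fit the wrapped parameters.
    loaded = torch.load(ckpt_path)
    scatter_optimizer_state_dict(optimizer, loaded["optimizer"], model)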
......
@@ -113,15 +113,11 @@ class FSDPWrapper(FSDP):
state_dict_type: StateDictType,
load_state_dict_type: StateDictType,
amp_autocast_dtype: Optional[torch.dtype] = None,
- use_local_state_dict: bool = False,
- load_local_state_dict: bool = False,
state_dict_cpu_offload: bool = True,
state_dict_rank0_only: bool = True,
**fsdp_kwargs,
):
self.precision = amp_autocast_dtype
- self.use_local_state_dict = use_local_state_dict
- self.load_local_state_dict = load_local_state_dict
self.state_dict_type = state_dict_type
self.load_state_dict_type = load_state_dict_type
self.offload_to_cpu = state_dict_cpu_offload
@@ -181,6 +177,7 @@ def build_fsdp(
reduce_dtype: Optional[torch.dtype] = None,
buffer_dtype: Optional[torch.dtype] = None,
amp_autocast_dtype: Optional[torch.dtype] = None,
+ # TODO: remove after the migration to state_dict_type completes
use_local_state_dict: bool = False,
load_local_state_dict: bool = False,
state_dict_type: Optional[StateDictType] = None,
@@ -258,8 +255,6 @@ def build_fsdp(
)
wrapper_kwargs = {
"amp_autocast_dtype": amp_autocast_dtype,
"use_local_state_dict": use_local_state_dict,
"load_local_state_dict": load_local_state_dict,
"state_dict_type": _state_dict_type,
"load_state_dict_type": _load_state_dict_type,
"state_dict_cpu_offload": state_dict_cpu_offload,
......
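The deprecated booleans kept in build_fsdp's signature exist only for backward compatibility; the resolution into _state_dict_type / _load_state_dict_type happens outside the hunks shown here. A possible mapping, sketched as an assumption rather than the verbatim d2go code:

# Hedged sketch (assumption): how build_fsdp could map the deprecated
# booleans onto StateDictType until the migration in the TODO completes.
from typing import Optional
from torch.distributed.fsdp import StateDictType

def _resolve_state_dict_type(
    explicit: Optional[StateDictType], use_local_flag: bool
) -> StateDictType:
    if explicit is not None:
        return explicit  # prefer the explicit enum when given
    # fall back to the deprecated boolean flag
    return (
        StateDictType.LOCAL_STATE_DICT if use_local_flag else StateDictType.FULL_STATE_DICT
    )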