Commit 46606a02 authored by David Yan, committed by Facebook GitHub Bot

Add tests for sharded_state_dict and fix compatibility problems

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/511

Add tests for sharded_state_dict integration in AIF Checkpointer

Fix compatibility problems, including:
1. fix small API errors in the call to flatten_sharded_optim_state_dict (the optimizer is now passed explicitly)
2. deprecate model.use_local_state_dict and model.load_local_state_dict in favor of the StateDictType-based attributes model.state_dict_type and model.load_state_dict_type
3. fix the automatic conversion of local_state_dict checkpoints to global checkpoints
4. fix T148056077: add metadata to differentiate between local_state_dict and sharded_state_dict when loading a directory with FSDPCheckpointer (see the sketch right after this list)
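
A minimal sketch of the metadata mechanism from item 4, written outside the checkpointer class. The helper names and the standalone open() calls are illustrative only; the actual commit implements this as FSDPCheckpointer._save_metadata / _load_metadata using self.path_manager (see the diff below).

import json
import os

from torch.distributed.fsdp import StateDictType


def save_metadata(checkpoint_dir: str, state_dict_type: StateDictType) -> None:
    # Record which kind of state dict the checkpoint directory holds.
    with open(os.path.join(checkpoint_dir, "metadata.json"), "w") as f:
        json.dump({"state_dict_type": state_dict_type.name}, f)


def load_state_dict_type(checkpoint_dir: str) -> StateDictType:
    # Checkpoints written before this change have no metadata file;
    # fall back to LOCAL_STATE_DICT, matching the fallback in the loader below.
    metadata_file = os.path.join(checkpoint_dir, "metadata.json")
    if not os.path.exists(metadata_file):
        return StateDictType.LOCAL_STATE_DICT
    with open(metadata_file) as f:
        name = json.load(f)["state_dict_type"]
    return StateDictType[name]  # e.g. "SHARDED_STATE_DICT" -> enum member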

Reviewed By: YanjunChen329

Differential Revision: D44160045

fbshipit-source-id: f607b7076d0e49b9407f9adfbc8ecfe439c3b0c9
parent fbc1c2e8
 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+import json
 import os
 from typing import Callable, cast, IO
@@ -49,25 +50,32 @@ class FSDPCheckpointer(QATCheckpointer):
         load_path = path
         if path:
             # loading path is a directory: local or sharded state dict is used
+            # TODO(T148056077): Make loading sharded state dicts more elegant
             if self.path_manager.isdir(path):
+                # Get state dict type from metadata file
+                metadata = self._load_metadata(path)
+                state_dict_type = (
+                    metadata["state_dict_type"] if metadata else "LOCAL_STATE_DICT"
+                )
+                assert state_dict_type in ["LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
+                type_str = "local" if state_dict_type == "LOCAL_STATE_DICT" else "sharded"
                 self.logger.info(
-                    "[FSDPCheckpointer] Loading from local or sharded checkpoint ..."
+                    f"[FSDPCheckpointer] Loading from {type_str} checkpoint ..."
                 )
-                self.model.load_local_state_dict = True
+                self.model.load_state_dict_type = StateDictType[state_dict_type]
                 load_path = os.path.join(path, f"rank{comm.get_rank()}.pth")
             # loading path is a file: full global state dict is used
             else:
                 self.logger.info(
-                    "[FSDPCheckpointer] Loading from global checkpoint ..."
+                    "[FSDPCheckpointer] Loading from full checkpoint ..."
                 )
-                self.model.load_local_state_dict = False
+                self.model.load_state_dict_type = StateDictType.FULL_STATE_DICT
         # Convert local ckpt to global ckpt when we load from a local ckpt but want to save to global ckpt
         convert_local_ckpt_to_global = (
             path
-            and self.model.load_local_state_dict
-            and not self.model.use_local_state_dict
+            and self.model.load_state_dict_type == StateDictType.LOCAL_STATE_DICT
+            and self.model.state_dict_type == StateDictType.FULL_STATE_DICT
         )
         # Load all checkpointables from local ckpt if we want to convert to global ckpt
@@ -156,7 +164,9 @@ class FSDPCheckpointer(QATCheckpointer):
                 self._save_file(data, save_file)
             # Main process tags last checkpoint if no errors in all processes
             comm.synchronize()
-            if comm.is_main_process() and tag_last_ckpt:
-                self.tag_last_checkpoint(name)
+            if comm.is_main_process():
+                self._save_metadata(new_save_dir)
+                if tag_last_ckpt:
+                    self.tag_last_checkpoint(name)
         elif comm.is_main_process():
             basename = "{}.pth".format(name)
@@ -176,6 +186,20 @@ class FSDPCheckpointer(QATCheckpointer):
         with interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher()):
             return super()._load_file(f)
+
+    def _save_metadata(self, path):
+        metadata_file = os.path.join(path, "metadata.json")
+        obj = {"state_dict_type": self.model.state_dict_type.name}
+        with self.path_manager.open(metadata_file, "w") as f:
+            json.dump(obj, f)
+
+    def _load_metadata(self, path):
+        metadata_file = os.path.join(path, "metadata.json")
+        if self.path_manager.exists(metadata_file):
+            with self.path_manager.open(metadata_file, "r") as f:
+                return json.load(f)
+        else:
+            return None

 def gather_optimizer_state_dict(optimizer, model: FSDPWrapper):
     """
@@ -198,7 +222,7 @@ def scatter_optimizer_state_dict(optimizer, optim_state_dict, model: FSDPWrapper
         optim_state_dict = FSDP.shard_full_optim_state_dict(optim_state_dict, model)
     elif model.state_dict_type == StateDictType.SHARDED_STATE_DICT:
         optim_state_dict = FSDP.flatten_sharded_optim_state_dict(
-            optim_state_dict, model
+            optim_state_dict, model, optimizer
         )
     optimizer.load_state_dict(optim_state_dict)
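
The last hunk above is the item 1 fix: PyTorch's FSDP.flatten_sharded_optim_state_dict takes the optimizer as its third argument, and the flattened result is loaded back with optimizer.load_state_dict. A minimal sketch of the corrected call pattern, mirroring the SHARDED_STATE_DICT branch of scatter_optimizer_state_dict; names other than the FSDP API are illustrative.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def restore_sharded_optim_state(
    sharded_osd: dict,
    model: FSDP,
    optimizer: torch.optim.Optimizer,
) -> None:
    # Convert the per-rank sharded optimizer state into the flat layout
    # this rank's optimizer expects, then load it.
    flat_osd = FSDP.flatten_sharded_optim_state_dict(sharded_osd, model, optimizer)
    optimizer.load_state_dict(flat_osd)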
@@ -113,15 +113,11 @@ class FSDPWrapper(FSDP):
         state_dict_type: StateDictType,
         load_state_dict_type: StateDictType,
         amp_autocast_dtype: Optional[torch.dtype] = None,
-        use_local_state_dict: bool = False,
-        load_local_state_dict: bool = False,
         state_dict_cpu_offload: bool = True,
         state_dict_rank0_only: bool = True,
         **fsdp_kwargs,
     ):
         self.precision = amp_autocast_dtype
-        self.use_local_state_dict = use_local_state_dict
-        self.load_local_state_dict = load_local_state_dict
         self.state_dict_type = state_dict_type
         self.load_state_dict_type = load_state_dict_type
         self.offload_to_cpu = state_dict_cpu_offload
@@ -181,6 +177,7 @@ def build_fsdp(
     reduce_dtype: Optional[torch.dtype] = None,
     buffer_dtype: Optional[torch.dtype] = None,
     amp_autocast_dtype: Optional[torch.dtype] = None,
+    # TODO: to remove after migration to state_dict_type completes
     use_local_state_dict: bool = False,
     load_local_state_dict: bool = False,
     state_dict_type: Optional[StateDictType] = None,
@@ -258,8 +255,6 @@ def build_fsdp(
     )
     wrapper_kwargs = {
         "amp_autocast_dtype": amp_autocast_dtype,
-        "use_local_state_dict": use_local_state_dict,
-        "load_local_state_dict": load_local_state_dict,
         "state_dict_type": _state_dict_type,
         "load_state_dict_type": _load_state_dict_type,
         "state_dict_cpu_offload": state_dict_cpu_offload,
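
The removals above implement item 2: FSDPWrapper no longer carries the use_local_state_dict / load_local_state_dict booleans, and build_fsdp keeps them only until the migration to state_dict_type completes. The conversion from the legacy booleans to the _state_dict_type / _load_state_dict_type enums happens inside build_fsdp and is not part of the hunks shown; below is one plausible sketch of that mapping, not the actual build_fsdp code.

from typing import Optional

from torch.distributed.fsdp import StateDictType


def resolve_state_dict_type(
    use_local_state_dict: bool,
    state_dict_type: Optional[StateDictType] = None,
) -> StateDictType:
    # An explicit enum wins; otherwise fall back to the deprecated boolean.
    if state_dict_type is not None:
        return state_dict_type
    return (
        StateDictType.LOCAL_STATE_DICT
        if use_local_state_dict
        else StateDictType.FULL_STATE_DICT
    )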