Commit 5ad2d57e authored by Anthony Chen, committed by Facebook GitHub Bot

Convert local checkpoint to global one automatically in d2go FSDP checkpointer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/446

## Design
Following D41861308, local checkpoints need to be converted to global ones before being loaded and used in non-FSDP-wrapped models. This diff implements the conversion at the d2go checkpointer level, allowing automatic conversion with minimal user intervention and no new config key.

In the previous diff, `FSDPWrapper` gained 2 loading modes and 2 saving modes: it uses `load_local_state_dict` to determine whether the ckpt we want to load is local or global, and uses `use_local_state_dict` to decide whether to save new ckpts as local or global. Thus, there are 4 combinations of loading/saving modes:
1. load local + save local
2. load local + save global
3. load global + save local
4. load global + save global

The local-to-global checkpoint conversion maps to mode 2: load local + save global. Thus, when the checkpointer is in mode 2, it automatically saves the model to a global ckpt right after it loads the local ckpt. Because this happens at the checkpointer level, normal training/eval can resume right after the conversion. This gives users a consistent and seamless experience with normal training/eval, while also providing a standalone ckpt-conversion feature via eval-only runs.
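For concreteness, here is a minimal sketch of the mode-2 check. The flag names `load_local_state_dict` / `use_local_state_dict` come from `FSDPWrapper` as used in the diff below; the helper function name is illustrative only:

```python
# Illustrative sketch: mirrors the convert_local_ckpt_to_global condition in the diff below.
def should_convert_to_global(path: str, model) -> bool:
    # Mode 2 from the list above: the ckpt being loaded is local (sharded),
    # but new ckpts are configured to be saved as global ones.
    return bool(path) and model.load_local_state_dict and not model.use_local_state_dict
```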

## Usage
Suppose we want to convert the local checkpoint `/tmp/model_final`. The user can run the same training command with two extra args: `MODEL.WEIGHTS=/tmp/model_final` and `FSDP.USE_LOCAL_STATE_DICT=False`
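For reference, a rough Python-side equivalent of those overrides, as a sketch assuming the usual detectron2-style `CfgNode` config used by d2go (`cfg` stands in for whatever config object your runner builds):

```python
# Sketch only: applies the two overrides named above to a detectron2-style CfgNode `cfg`.
cfg.merge_from_list(
    [
        "MODEL.WEIGHTS", "/tmp/model_final",  # directory holding the local (sharded) checkpoint
        "FSDP.USE_LOCAL_STATE_DICT", False,   # load local, but save new ckpts as global (mode 2)
    ]
)
```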

Wiki: https://www.internalfb.com/intern/wiki/Mobile_Vision/Detectron2Go/D2Go_Tutorials/Diffusion_Pipeline/Diffusion_Model_Inference/#using-checkpoints-traine

Reviewed By: wat3rBro

Differential Revision: D41926662

fbshipit-source-id: 18a62607a79b0e917d929e9ea85ac1658fb895ca
parent eea6339f
@@ -28,17 +28,33 @@ class FSDPCheckpointer(QATCheckpointer):
             In general users should not resume non-FSDP training with FSDP.
         """
         if isinstance(self.model, FSDPWrapper):
-            if path is not None:
+            load_path = path
+            if path:
                 # loading path is a directory: sharded local state dict is used
                 if self.path_manager.isdir(path):
                     self.logger.info(
                         "[FSDPCheckpointer] Loading from local checkpoint ..."
                     )
                     self.model.load_local_state_dict = True
-                    path = os.path.join(path, f"rank{comm.get_rank()}.pth")
+                    load_path = os.path.join(path, f"rank{comm.get_rank()}.pth")
                 # loading path is a file: full global state dict is used
                 else:
                     self.logger.info(
                         "[FSDPCheckpointer] Loading from global checkpoint ..."
                     )
                     self.model.load_local_state_dict = False
+            # Convert local ckpt to global ckpt when we load from a local ckpt but want to save to global ckpt
+            convert_local_ckpt_to_global = (
+                path
+                and self.model.load_local_state_dict
+                and not self.model.use_local_state_dict
+            )
+            # Load all checkpointables from local ckpt if we want to convert to global ckpt
             checkpointables_iter = (
                 self.checkpointables.keys()
-                if checkpointables is None
+                if checkpointables is None or convert_local_ckpt_to_global
                 else checkpointables
             )
             checkpointables_filtered = [
@@ -47,22 +63,39 @@ class FSDPCheckpointer(QATCheckpointer):
                 if name not in ["optimizer", "ema_state"]
             ]
-            checkpoint = super().load(path, checkpointables=checkpointables_filtered)
+            checkpoint = super().load(
+                load_path, checkpointables=checkpointables_filtered
+            )
             if "optimizer" in checkpointables_iter:
-                self.logger.info("Loading optimizer from {} ...".format(path))
+                self.logger.info(
+                    f"[FSDPCheckpointer] Loading optimizer from {load_path} ..."
+                )
                 optimizer = self.checkpointables["optimizer"]
                 osd = checkpoint.pop("optimizer")
                 scatter_optimizer_state_dict(optimizer, osd, self.model)
             if "ema_state" in checkpointables_iter:
-                self.logger.info("Loading ema_state from {} ...".format(path))
+                self.logger.info(
+                    f"[FSDPCheckpointer] Loading ema_state from {load_path} ..."
+                )
                 ema_state = checkpoint.pop("ema_state")
                 scatter_ema_state_dict(ema_state, self.model)
+            # Convert local ckpt by resaving the current state
+            if convert_local_ckpt_to_global:
+                self.logger.info(
+                    "[FSDPCheckpointer] Converting local FSDP checkpoint to global checkpoint ..."
+                )
+                self.save(os.path.basename(path), tag_last_ckpt=False, **checkpoint)
+                self.logger.info(
+                    "[FSDPCheckpointer] Local-to-global checkpoint conversion finishes"
+                )
             # return all remaining checkpoints
             return checkpoint
         else:
             return super().load(path, checkpointables=checkpointables)

-    def save(self, name: str, **kwargs) -> None:
+    def save(self, name: str, tag_last_ckpt=True, **kwargs) -> None:
         """
         Add support for saving sharding models and optimizers.
         The rest of the code is copied from implementation in the superclass
@@ -101,14 +134,15 @@ class FSDPCheckpointer(QATCheckpointer):
             self._save_file(data, save_file)
             # Main process tags last checkpoint if no errors in all processes
             comm.synchronize()
-            if comm.is_main_process():
+            if comm.is_main_process() and tag_last_ckpt:
                 self.tag_last_checkpoint(name)
         elif comm.is_main_process():
             basename = "{}.pth".format(name)
             save_file = os.path.join(self.save_dir, basename)
             assert os.path.basename(save_file) == basename, basename
             self._save_file(data, save_file)
-            self.tag_last_checkpoint(basename)
+            if tag_last_ckpt:
+                self.tag_last_checkpoint(basename)

     def _save_file(self, data, filename):
         self.logger.info("Saving checkpoint to {}".format(filename))