Commit 5ad2d57e authored by Anthony Chen, committed by Facebook GitHub Bot

Convert local checkpoint to global one automatically in d2go FSDP checkpointer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/446

## Design
Following D41861308, local checkpoints need to be converted to global ones before being loaded and used in non-FSDP-wrapped models. This diff implements that conversion at the d2go checkpointer level, allowing automatic conversion with minimal user intervention and no new config key.

In the previous diff, `FSDPWrapper` gained 2 loading modes and 2 saving modes: it uses `load_local_state_dict` to determine whether the ckpt being loaded is local or global, and `use_local_state_dict` to decide whether new ckpts are saved as local or global. Thus, there are 4 combinations of loading/saving modes:
1. load local + save local
2. load local + save global
3. load global + save local
4. load global + save global

Local-to-global checkpoint conversion maps to mode 2 (load local + save global). Thus, when the checkpointer is in mode 2, it automatically saves the model to a global ckpt right after loading the local ckpt. Because this happens at the checkpointer level, normal training/eval can resume after the conversion. This gives users a consistent, seamless experience during normal training/eval, while also providing a standalone ckpt conversion feature via eval-only mode.
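To make the mode mapping concrete, here is a minimal, hypothetical sketch (not the d2go code itself) of how the two `FSDPWrapper` flags map onto the four modes above; the real predicate appears in the diff below as `convert_local_ckpt_to_global`.

```python
# Hypothetical sketch: only mode 2 (load local + save global) triggers the
# automatic local-to-global conversion.
def should_convert_to_global(load_local_state_dict: bool, use_local_state_dict: bool) -> bool:
    return load_local_state_dict and not use_local_state_dict


assert not should_convert_to_global(True, True)    # mode 1: load local + save local
assert should_convert_to_global(True, False)       # mode 2: load local + save global -> convert
assert not should_convert_to_global(False, True)   # mode 3: load global + save local
assert not should_convert_to_global(False, False)  # mode 4: load global + save global
```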

## Usage
Suppose we want to convert the local checkpoint `/tmp/model_final`. The user can run the same training command with the extra args `MODEL.WEIGHTS=/tmp/model_final` and `FSDP.USE_LOCAL_STATE_DICT=False`.
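For illustration only, the same two overrides expressed against a bare config object; this uses a plain `CfgNode` as a stand-in for the real d2go config (which is built by the d2go runner and already contains these keys), so only the key names are taken from the summary above.

```python
from fvcore.common.config import CfgNode

# Stand-in config with just the two keys we care about (defaults here are made up).
cfg = CfgNode()
cfg.MODEL = CfgNode({"WEIGHTS": ""})
cfg.FSDP = CfgNode({"USE_LOCAL_STATE_DICT": True})

# Equivalent of appending `MODEL.WEIGHTS=/tmp/model_final FSDP.USE_LOCAL_STATE_DICT=False`
# to the training command line.
cfg.merge_from_list(
    ["MODEL.WEIGHTS", "/tmp/model_final", "FSDP.USE_LOCAL_STATE_DICT", False]
)
```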

Wiki: https://www.internalfb.com/intern/wiki/Mobile_Vision/Detectron2Go/D2Go_Tutorials/Diffusion_Pipeline/Diffusion_Model_Inference/#using-checkpoints-traine

Reviewed By: wat3rBro

Differential Revision: D41926662

fbshipit-source-id: 18a62607a79b0e917d929e9ea85ac1658fb895ca
parent eea6339f
```diff
@@ -28,17 +28,33 @@ class FSDPCheckpointer(QATCheckpointer):
         In general users should not resume non-FSDP training with FSDP.
         """
         if isinstance(self.model, FSDPWrapper):
-            if path is not None:
+            load_path = path
+            if path:
                 # loading path is a directory: sharded local state dict is used
                 if self.path_manager.isdir(path):
+                    self.logger.info(
+                        "[FSDPCheckpointer] Loading from local checkpoint ..."
+                    )
                     self.model.load_local_state_dict = True
-                    path = os.path.join(path, f"rank{comm.get_rank()}.pth")
+                    load_path = os.path.join(path, f"rank{comm.get_rank()}.pth")
                 # loading path is a file: full global state dict is used
                 else:
+                    self.logger.info(
+                        "[FSDPCheckpointer] Loading from global checkpoint ..."
+                    )
                     self.model.load_local_state_dict = False
+            # Convert local ckpt to global ckpt when we load from a local ckpt but want to save to global ckpt
+            convert_local_ckpt_to_global = (
+                path
+                and self.model.load_local_state_dict
+                and not self.model.use_local_state_dict
+            )
+            # Load all checkpointables from local ckpt if we want to convert to global ckpt
             checkpointables_iter = (
                 self.checkpointables.keys()
-                if checkpointables is None
+                if checkpointables is None or convert_local_ckpt_to_global
                 else checkpointables
             )
             checkpointables_filtered = [
@@ -47,22 +63,39 @@ class FSDPCheckpointer(QATCheckpointer):
                 if name not in ["optimizer", "ema_state"]
             ]
-            checkpoint = super().load(path, checkpointables=checkpointables_filtered)
+            checkpoint = super().load(
+                load_path, checkpointables=checkpointables_filtered
+            )
             if "optimizer" in checkpointables_iter:
-                self.logger.info("Loading optimizer from {} ...".format(path))
+                self.logger.info(
+                    f"[FSDPCheckpointer] Loading optimizer from {load_path} ..."
+                )
                 optimizer = self.checkpointables["optimizer"]
                 osd = checkpoint.pop("optimizer")
                 scatter_optimizer_state_dict(optimizer, osd, self.model)
             if "ema_state" in checkpointables_iter:
-                self.logger.info("Loading ema_state from {} ...".format(path))
+                self.logger.info(
+                    f"[FSDPCheckpointer] Loading ema_state from {load_path} ..."
+                )
                 ema_state = checkpoint.pop("ema_state")
                 scatter_ema_state_dict(ema_state, self.model)
+            # Convert local ckpt by resaving the current state
+            if convert_local_ckpt_to_global:
+                self.logger.info(
+                    "[FSDPCheckpointer] Converting local FSDP checkpoint to global checkpoint ..."
+                )
+                self.save(os.path.basename(path), tag_last_ckpt=False, **checkpoint)
+                self.logger.info(
+                    "[FSDPCheckpointer] Local-to-global checkpoint conversion finishes"
+                )
             # return all remaining checkpoints
             return checkpoint
         else:
             return super().load(path, checkpointables=checkpointables)
 
-    def save(self, name: str, **kwargs) -> None:
+    def save(self, name: str, tag_last_ckpt=True, **kwargs) -> None:
         """
         Add support for saving sharding models and optimizers.
         The rest of the code is copied from implementation in the superclass
@@ -101,14 +134,15 @@ class FSDPCheckpointer(QATCheckpointer):
             self._save_file(data, save_file)
             # Main process tags last checkpoint if no errors in all processes
             comm.synchronize()
-            if comm.is_main_process():
+            if comm.is_main_process() and tag_last_ckpt:
                 self.tag_last_checkpoint(name)
         elif comm.is_main_process():
             basename = "{}.pth".format(name)
             save_file = os.path.join(self.save_dir, basename)
             assert os.path.basename(save_file) == basename, basename
             self._save_file(data, save_file)
-            self.tag_last_checkpoint(basename)
+            if tag_last_ckpt:
+                self.tag_last_checkpoint(basename)
 
     def _save_file(self, data, filename):
         self.logger.info("Saving checkpoint to {}".format(filename))
```
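As an aside, here is a self-contained sketch (plain PyTorch, not d2go code) of the two on-disk layouts the checkpointer distinguishes via `isdir(path)`: a local ckpt is a directory of per-rank shards named `rank{i}.pth`, while a global ckpt is a single `.pth` file, which is what the conversion path produces. The contents below are dummies for illustration only.

```python
import os
import torch

# Local ckpt: a directory containing one shard per rank (dummy contents).
local_ckpt_dir = "/tmp/model_final"
os.makedirs(local_ckpt_dir, exist_ok=True)
for rank in range(2):  # pretend world_size == 2
    torch.save(
        {"model": {"dummy": torch.zeros(1)}},
        os.path.join(local_ckpt_dir, f"rank{rank}.pth"),
    )

# Global ckpt: a single file holding the full (unsharded) state dict.
torch.save({"model": {"dummy": torch.zeros(1)}}, "/tmp/model_final.pth")
```

In the conversion path, `save(os.path.basename(path), tag_last_ckpt=False, **checkpoint)` writes the global `model_final.pth` without calling `tag_last_checkpoint`, which leaves the `last_checkpoint` pointer file in the save dir untouched, so normal training/eval resumption behavior is unchanged.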