"vscode:/vscode.git/clone" did not exist on "b3d10d6d65a80593627c6738fbeded2f69b5129f"
Commit 2d4d2f29 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

resolve CPU OOM with FSDP checkpointer

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/632

Reviewed By: yzhao30

Differential Revision: D50663689

fbshipit-source-id: 5c4c1dd2e5d2087be5aec268672bb5e7fc329df9
parent 7ace1ef0
......@@ -220,12 +220,15 @@ class FSDPCheckpointer(QATCheckpointer):
self.logger.info("Finished saving checkpoint to {}".format(filename))
def _load_file(self, f: str):
with (
interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher())
if isinstance(self.model, FSDPWrapper)
and self.model.state_dict_type != StateDictType.FULL_STATE_DICT
else nullcontext() # FULL_STATE_DICT doesn't need interleaving
):
if isinstance(self.model, FSDPWrapper):
with (
interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher())
if self.model.state_dict_type != StateDictType.FULL_STATE_DICT
else nullcontext() # FULL_STATE_DICT doesn't need interleaving
):
# use mmap for FSDP checkpoints
return torch.load(f, map_location=torch.device("cpu"), mmap=True)
else:
return super()._load_file(f)
def _save_metadata(self, path):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment