Commit 27918553 authored by SK Bong's avatar SK Bong Committed by Facebook GitHub Bot
Browse files

Add proper barriers around FSDP checkpointing

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/621

There should be barriers around FSDP checkpointing to ensure other ranks do not continue to training while rank 0 is still checkpointing

Also add log after checkpoint finishes

Reviewed By: wat3rBro

Differential Revision: D49541229

fbshipit-source-id: ac8c086eb0d65611be0b258e3006d9e14b7387ad
parent 477629d0
...@@ -188,19 +188,33 @@ class FSDPCheckpointer(QATCheckpointer): ...@@ -188,19 +188,33 @@ class FSDPCheckpointer(QATCheckpointer):
self._save_metadata(new_save_dir) self._save_metadata(new_save_dir)
if tag_last_ckpt: if tag_last_ckpt:
self.tag_last_checkpoint(name) self.tag_last_checkpoint(name)
elif comm.is_main_process(): else:
basename = "{}.pth".format(name) if comm.is_main_process():
save_file = os.path.join(self.save_dir, basename) basename = "{}.pth".format(name)
assert os.path.basename(save_file) == basename, basename save_file = os.path.join(self.save_dir, basename)
self._save_file(data, save_file) assert os.path.basename(save_file) == basename, basename
if tag_last_ckpt:
self.tag_last_checkpoint(basename) self.logger.info(
f"[FSDPCheckpointer] Rank {comm.get_rank()} is checkpointing {save_file}."
)
self._save_file(data, save_file)
if tag_last_ckpt:
self.tag_last_checkpoint(basename)
else:
self.logger.info(
f"[FSDPCheckpointer] Rank {comm.get_rank()} is deferring checkpointing to the main process."
)
comm.synchronize()
def _save_file(self, data, filename): def _save_file(self, data, filename):
self.logger.info("Saving checkpoint to {}".format(filename)) self.logger.info("Saving checkpoint to {}".format(filename))
with self.path_manager.open(filename, "wb") as f: with self.path_manager.open(filename, "wb") as f:
torch.save(data, cast(IO[bytes], f)) torch.save(data, cast(IO[bytes], f))
self.logger.info("Finished saving checkpoint to {}".format(filename))
def _load_file(self, f: str): def _load_file(self, f: str):
# Limit the read concurrency to avoid QPS overload # Limit the read concurrency to avoid QPS overload
with interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher()): with interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher()):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment