Commit 3111ae59 authored by Matthew Yu, committed by Facebook GitHub Bot

turn off interleaving if only saving on rank0

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/482

We should avoid interleaving during save when save is only called on one process:
```
if comm.is_main_process():
  save()
```
This is because interleaving calls comm.synchronize(), a collective that only returns once every rank reaches it; with only one process saving, the call waits indefinitely.
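
For intuition, here is a minimal, hypothetical sketch of what a rank-interleaving helper along these lines could look like (`interleave_by_rank_sketch` and its structure are illustrative, not the actual d2go `interleave_by_rank`):

```
# Hypothetical sketch (not the actual d2go implementation) of a rank-interleaving
# context manager, to illustrate why it hangs when only rank 0 enters it.
import contextlib

from detectron2.utils import comm


@contextlib.contextmanager
def interleave_by_rank_sketch(concurrency_limit: int = 8):
    rank = comm.get_rank()
    world_size = comm.get_world_size()
    # Ranks take turns in waves of `concurrency_limit`, with a collective
    # barrier after each wave, so every rank must enter this context manager.
    for wave_start in range(0, world_size, concurrency_limit):
        if wave_start <= rank < wave_start + concurrency_limit:
            yield  # this rank's turn to run the body (e.g. write its shard)
        # comm.synchronize() only returns once all ranks reach it; if save()
        # is guarded by comm.is_main_process(), ranks 1..N-1 never get here
        # and rank 0 waits forever.
        comm.synchronize()
```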

This diff updates the FSDP checkpointer to use save(interleave=False) when running on one process.
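
Assuming the checkpointer exposes the flag as described, the single-process call site would then look roughly like this (the `checkpointer` instance and checkpoint name are placeholders):

```
# Sketch of the single-process save path; the exact FSDPCheckpointer.save
# signature may differ from what is shown here.
if comm.is_main_process():
    # Only rank 0 saves, so skip interleaving: there are no other ranks
    # to synchronize with.
    checkpointer.save("model_final", interleave=False)
```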

Reviewed By: wat3rBro, YanjunChen329

Differential Revision: D43526328

fbshipit-source-id: 672993a87af627aca090384b0c218798bd42fcde
parent 4e4a865c
```
@@ -135,7 +135,9 @@ class FSDPCheckpointer(QATCheckpointer):
             basename = "rank{}.pth".format(comm.get_rank())
             save_file = os.path.join(new_save_dir, basename)
             assert os.path.basename(save_file) == basename, basename
-            self._save_file(data, save_file)
+            # allow 8 GPUs to write to manifold at the same time
+            with interleave_by_rank(concurrency_limit=8):
+                self._save_file(data, save_file)
         # Main process tags last checkpoint if no errors in all processes
         comm.synchronize()
         if comm.is_main_process() and tag_last_ckpt:
@@ -149,11 +151,9 @@ class FSDPCheckpointer(QATCheckpointer):
             self.tag_last_checkpoint(basename)
 
     def _save_file(self, data, filename):
-        # allow 8 GPUs to write to manifold at the same time
-        with interleave_by_rank(concurrency_limit=8):
-            self.logger.info("Saving checkpoint to {}".format(filename))
-            with self.path_manager.open(filename, "wb") as f:
-                torch.save(data, cast(IO[bytes], f))
+        self.logger.info("Saving checkpoint to {}".format(filename))
+        with self.path_manager.open(filename, "wb") as f:
+            torch.save(data, cast(IO[bytes], f))
 
     def _load_file(self, f: str):
         # allow 8 GPUs to read from manifold at the same time
```