Commit 4dc7e94c authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

make interleave concurrency configurable

Reviewed By: mattcyu1

Differential Revision: D43557002

fbshipit-source-id: b929875f479b215b3e6034a03d8bea3e4cb3c2f8
parent cd9c320d
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import os
from typing import cast, IO
from typing import Callable, cast, IO
import detectron2.utils.comm as comm
import torch
......@@ -15,12 +15,25 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
)
def get_max_checkpoint_concurrency() -> int:
    """Default concurrency-limit fetcher for FSDPCheckpointer.

    Returns the distributed world size, i.e. one slot per rank, so by
    default every rank may touch storage at the same time (no throttling).
    """
    world_size = comm.get_world_size()
    return world_size
# TODO: replace FSDPCheckpointer with central D2GoCheckpointer
class FSDPCheckpointer(QATCheckpointer):
"""
Extend the Checkpointer to support saving/loading FSDP models
"""
def __init__(
    self,
    *args,
    concurrency_limit_fetcher: Callable[[], int] = get_max_checkpoint_concurrency,
    **kwargs,
):
    """Initialize the checkpointer.

    Args:
        concurrency_limit_fetcher: zero-arg callable returning the maximum
            number of ranks allowed to read/write checkpoint storage at once.
            Defaults to the world size (effectively no throttling). Stored as
            a callable so the limit is resolved lazily at save/load time.
        *args, **kwargs: forwarded unchanged to QATCheckpointer.
    """
    super().__init__(*args, **kwargs)
    self._concurrency_limit_fetcher: Callable[[], int] = concurrency_limit_fetcher
def is_distributed(self) -> bool:
    """Always True: FSDP checkpoints are saved per rank, so every process
    participates in checkpointing (see the rank{N}.pth naming in save)."""
    return True
......@@ -135,8 +148,10 @@ class FSDPCheckpointer(QATCheckpointer):
basename = "rank{}.pth".format(comm.get_rank())
save_file = os.path.join(new_save_dir, basename)
assert os.path.basename(save_file) == basename, basename
# allow 8 GPUs to write to manifold at the same time
with interleave_by_rank(concurrency_limit=8):
# Limit the write concurrency to avoid QPS overload
with interleave_by_rank(
concurrency_limit=self._concurrency_limit_fetcher()
):
self._save_file(data, save_file)
# Main process tags last checkpoint if no errors in all processes
comm.synchronize()
......@@ -156,8 +171,8 @@ class FSDPCheckpointer(QATCheckpointer):
torch.save(data, cast(IO[bytes], f))
def _load_file(self, f: str):
    """Load a checkpoint file via the parent implementation, with ranks
    interleaved to cap how many read from storage simultaneously.

    The cap comes from the configured fetcher so deployments can tune it
    to avoid overloading the storage backend (QPS limits).
    """
    read_limit = self._concurrency_limit_fetcher()
    with interleave_by_rank(concurrency_limit=read_limit):
        return super()._load_file(f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment