"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6b221920d7542fc768244bd06ca919dffdd56ed0"
Commit 4dc7e94c authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

make interleave concurrency configurable

Reviewed By: mattcyu1

Differential Revision: D43557002

fbshipit-source-id: b929875f479b215b3e6034a03d8bea3e4cb3c2f8
parent cd9c320d
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import os import os
from typing import cast, IO from typing import Callable, cast, IO
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
import torch import torch
...@@ -15,12 +15,25 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import ( ...@@ -15,12 +15,25 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import (
) )
def get_max_checkpoint_concurrency() -> int:
    """Return the default checkpoint-I/O concurrency limit.

    Uses the distributed world size (via detectron2's comm utilities), i.e.
    by default every rank is allowed to read/write at the same time.
    """
    world_size = comm.get_world_size()
    return world_size
# TODO: replace FSDPCheckpointer with central D2GoCheckpointer # TODO: replace FSDPCheckpointer with central D2GoCheckpointer
class FSDPCheckpointer(QATCheckpointer): class FSDPCheckpointer(QATCheckpointer):
""" """
Extend the Checkpointer to support saving/loading FSDP models Extend the Checkpointer to support saving/loading FSDP models
""" """
def __init__(
    self,
    *args,
    concurrency_limit_fetcher: Callable[[], int] = get_max_checkpoint_concurrency,
    **kwargs,
):
    """Create an FSDP-aware checkpointer.

    Args:
        concurrency_limit_fetcher: zero-argument callable that decides how
            many ranks may access storage concurrently during save/load.
            Defaults to ``get_max_checkpoint_concurrency`` (the world size).

    All other positional/keyword arguments are forwarded unchanged to the
    parent checkpointer.
    """
    super().__init__(*args, **kwargs)
    # Keep the callable itself (not its result) so the limit is resolved
    # lazily at each save/load rather than frozen at construction time.
    self._concurrency_limit_fetcher: Callable[[], int] = concurrency_limit_fetcher
def is_distributed(self) -> bool:
    """Report that this checkpointer always operates in distributed mode."""
    return True
...@@ -135,8 +148,10 @@ class FSDPCheckpointer(QATCheckpointer): ...@@ -135,8 +148,10 @@ class FSDPCheckpointer(QATCheckpointer):
basename = "rank{}.pth".format(comm.get_rank()) basename = "rank{}.pth".format(comm.get_rank())
save_file = os.path.join(new_save_dir, basename) save_file = os.path.join(new_save_dir, basename)
assert os.path.basename(save_file) == basename, basename assert os.path.basename(save_file) == basename, basename
# allow 8 GPUs to write to manifold at the same time # Limit the write concurrency to avoid QPS overload
with interleave_by_rank(concurrency_limit=8): with interleave_by_rank(
concurrency_limit=self._concurrency_limit_fetcher()
):
self._save_file(data, save_file) self._save_file(data, save_file)
# Main process tags last checkpoint if no errors in all processes # Main process tags last checkpoint if no errors in all processes
comm.synchronize() comm.synchronize()
...@@ -156,8 +171,8 @@ class FSDPCheckpointer(QATCheckpointer): ...@@ -156,8 +171,8 @@ class FSDPCheckpointer(QATCheckpointer):
torch.save(data, cast(IO[bytes], f)) torch.save(data, cast(IO[bytes], f))
def _load_file(self, f: str):
    """Load a checkpoint file while throttling concurrent readers.

    Delegates to the parent implementation, but wraps the read in
    ``interleave_by_rank`` so that at most
    ``self._concurrency_limit_fetcher()`` ranks hit storage at once
    (limits read QPS on the backing store).
    """
    read_slots = self._concurrency_limit_fetcher()
    with interleave_by_rank(concurrency_limit=read_slots):
        return super()._load_file(f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment