Commit 54d9d91b authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

only let local master to download fsdp full checkpoint

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/626

Reviewed By: YanjunChen329

Differential Revision: D50135150

fbshipit-source-id: 6c85d4e966bb9e399c0fc17046fd1318bfbb1546
parent b375c290
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import json
import os
from contextlib import nullcontext
from typing import Callable, cast, IO
import detectron2.utils.comm as comm
......@@ -78,6 +79,9 @@ class FSDPCheckpointer(QATCheckpointer):
"[FSDPCheckpointer] Loading from FULL_STATE_DICT checkpoint ..."
)
self.model.load_state_dict_type = StateDictType.FULL_STATE_DICT
# since checkpoints are the same across ranks, we can download from
# rank0 and shared the local file.
load_path = self._get_local_path_per_host(load_path)
_log_api_usage_on_main_process(
f"{LOG_API_IDENTIFIER}.load.fsdp.{self.model.load_state_dict_type.name}" # pyre-ignore
......@@ -216,8 +220,12 @@ class FSDPCheckpointer(QATCheckpointer):
self.logger.info("Finished saving checkpoint to {}".format(filename))
def _load_file(self, f: str):
# Limit the read concurrency to avoid QPS overload
with interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher()):
with (
interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher())
if isinstance(self.model, FSDPWrapper)
and self.model.state_dict_type != StateDictType.FULL_STATE_DICT
else nullcontext() # FULL_STATE_DICT doesn't need interleaving
):
return super()._load_file(f)
def _save_metadata(self, path):
......@@ -233,3 +241,33 @@ class FSDPCheckpointer(QATCheckpointer):
return json.load(f)
else:
return None
def _get_local_path_per_host(self, path: str) -> str:
"""Download file only on local master, return the downloaded path for all ranks"""
from torchtnt.utils.distributed import get_local_rank, get_local_world_size
self.logger.info("Start getting local path per host ...")
# check if paths are the same on the same node
all_paths = comm.all_gather(path)
local_master = (
comm.get_rank() // get_local_world_size() * get_local_world_size()
)
if path != all_paths[local_master]:
raise ValueError(
f"All paths must be the same on the same node, got {path} vs {all_paths[local_master]}"
)
# local master downloads the file, while non-master skips
if get_local_rank() == 0:
self.logger.info(f"Start downloading {path} to local file ...")
local_path = self.path_manager.get_local_path(path)
self.logger.info(f"Finished downloading {path} to local file: {local_path}")
else:
local_path = None
self.logger.info("Waiting for local master to finish downloading ...")
# broadcast the local path to all other ranks
local_paths = comm.all_gather(local_path)
local_path = local_paths[local_master]
assert local_path is not None, f"Local path is None, {local_paths=}"
self.logger.info("Finished getting local path per host")
return local_path
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment