Commit 54d9d91b authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

only let local master to download fsdp full checkpoint

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/626

Reviewed By: YanjunChen329

Differential Revision: D50135150

fbshipit-source-id: 6c85d4e966bb9e399c0fc17046fd1318bfbb1546
parent b375c290
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import json import json
import os import os
from contextlib import nullcontext
from typing import Callable, cast, IO from typing import Callable, cast, IO
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
...@@ -78,6 +79,9 @@ class FSDPCheckpointer(QATCheckpointer): ...@@ -78,6 +79,9 @@ class FSDPCheckpointer(QATCheckpointer):
"[FSDPCheckpointer] Loading from FULL_STATE_DICT checkpoint ..." "[FSDPCheckpointer] Loading from FULL_STATE_DICT checkpoint ..."
) )
self.model.load_state_dict_type = StateDictType.FULL_STATE_DICT self.model.load_state_dict_type = StateDictType.FULL_STATE_DICT
# since checkpoints are the same across ranks, we can download from
# rank0 and shared the local file.
load_path = self._get_local_path_per_host(load_path)
_log_api_usage_on_main_process( _log_api_usage_on_main_process(
f"{LOG_API_IDENTIFIER}.load.fsdp.{self.model.load_state_dict_type.name}" # pyre-ignore f"{LOG_API_IDENTIFIER}.load.fsdp.{self.model.load_state_dict_type.name}" # pyre-ignore
...@@ -216,8 +220,12 @@ class FSDPCheckpointer(QATCheckpointer): ...@@ -216,8 +220,12 @@ class FSDPCheckpointer(QATCheckpointer):
self.logger.info("Finished saving checkpoint to {}".format(filename)) self.logger.info("Finished saving checkpoint to {}".format(filename))
def _load_file(self, f: str): def _load_file(self, f: str):
# Limit the read concurrency to avoid QPS overload with (
with interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher()): interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher())
if isinstance(self.model, FSDPWrapper)
and self.model.state_dict_type != StateDictType.FULL_STATE_DICT
else nullcontext() # FULL_STATE_DICT doesn't need interleaving
):
return super()._load_file(f) return super()._load_file(f)
def _save_metadata(self, path): def _save_metadata(self, path):
...@@ -233,3 +241,33 @@ class FSDPCheckpointer(QATCheckpointer): ...@@ -233,3 +241,33 @@ class FSDPCheckpointer(QATCheckpointer):
return json.load(f) return json.load(f)
else: else:
return None return None
def _get_local_path_per_host(self, path: str) -> str:
"""Download file only on local master, return the downloaded path for all ranks"""
from torchtnt.utils.distributed import get_local_rank, get_local_world_size
self.logger.info("Start getting local path per host ...")
# check if paths are the same on the same node
all_paths = comm.all_gather(path)
local_master = (
comm.get_rank() // get_local_world_size() * get_local_world_size()
)
if path != all_paths[local_master]:
raise ValueError(
f"All paths must be the same on the same node, got {path} vs {all_paths[local_master]}"
)
# local master downloads the file, while non-master skips
if get_local_rank() == 0:
self.logger.info(f"Start downloading {path} to local file ...")
local_path = self.path_manager.get_local_path(path)
self.logger.info(f"Finished downloading {path} to local file: {local_path}")
else:
local_path = None
self.logger.info("Waiting for local master to finish downloading ...")
# broadcast the local path to all other ranks
local_paths = comm.all_gather(local_path)
local_path = local_paths[local_master]
assert local_path is not None, f"Local path is None, {local_paths=}"
self.logger.info("Finished getting local path per host")
return local_path
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment