Unverified commit f5e59ee7, authored by Artem Perevedentsev, committed by GitHub
Browse files

[Performance] Add prefetch for checkpoints to OS page cache (#36012)


Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
parent 9b005edc
...@@ -62,6 +62,9 @@ class LoadConfig: ...@@ -62,6 +62,9 @@ class LoadConfig:
This is recommended for models on network filesystems (e.g., Lustre, NFS) This is recommended for models on network filesystems (e.g., Lustre, NFS)
as it avoids inefficient random reads, significantly speeding up model as it avoids inefficient random reads, significantly speeding up model
initialization. However, it uses more CPU RAM. initialization. However, it uses more CPU RAM.
- "prefetch": Checkpoint files are read into the OS page cache before
workers load them, speeding up the model loading phase. Useful on
network or high-latency storage.
- "torchao": Weights are loaded in upfront and then reconstructed - "torchao": Weights are loaded in upfront and then reconstructed
into torchao tensor subclasses. This is used when the checkpoint into torchao tensor subclasses. This is used when the checkpoint
was quantized using torchao and saved using safetensors. was quantized using torchao and saved using safetensors.
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for downloading and initializing model weights.""" """Utilities for downloading and initializing model weights."""
import asyncio
import concurrent.futures import concurrent.futures
import fnmatch import fnmatch
import glob import glob
...@@ -9,6 +10,7 @@ import hashlib ...@@ -9,6 +10,7 @@ import hashlib
import json import json
import os import os
import tempfile import tempfile
import threading
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
...@@ -720,6 +722,71 @@ def np_cache_weights_iterator( ...@@ -720,6 +722,71 @@ def np_cache_weights_iterator(
yield name, torch.from_numpy(param) yield name, torch.from_numpy(param)
def _prefetch_checkpoint(
    file_path: str,
    block_size: int = 16 * 1024 * 1024,
) -> None:
    """Prefetch a checkpoint file into the OS page cache.

    Reads the file sequentially from start to end and discards the data, so
    the kernel caches its pages before workers load the same file.

    Args:
        file_path: Path of the checkpoint file to read.
        block_size: Number of bytes requested per read. Defaults to 16MB.

    Raises:
        ValueError: If ``block_size`` is not positive.
        OSError: If the file cannot be opened or read.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # The bytes are thrown away, so a single reusable buffer is enough; this
    # avoids allocating a fresh block-sized bytes object on every iteration.
    buffer = bytearray(block_size)

    # buffering=0 returns a raw file object, so readinto() fills our buffer
    # directly without an intermediate copy through a BufferedReader.
    with open(file_path, "rb", buffering=0) as f:
        # readinto() returns the number of bytes read; 0 signals end of file.
        while f.readinto(buffer):
            pass
def _prefetch_all_checkpoints(sorted_files: list[str]) -> None:
    """Start prefetching checkpoint files into page cache in a background thread.

    The file list is striped across distributed ranks (``rank::world_size``)
    so each file is prefetched by exactly one rank. The call returns
    immediately; the reads run on a daemon thread, at most eight at a time.
    A file that fails to prefetch is logged as a warning and skipped.

    Args:
        sorted_files: Checkpoint file paths. Every rank must pass the same
            ordering for the striping to cover each file exactly once.
    """
    if torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        rank = 0
        world_size = 1

    num_prefetch_threads = 8
    # NOTE(review): striping uses the global rank. On multi-node runs without
    # a shared page cache, each node would only warm part of the file set --
    # confirm this is intended.
    paths_to_prefetch = sorted_files[rank::world_size]
    total_for_rank = len(paths_to_prefetch)
    if total_for_rank == 0:
        # Nothing assigned to this rank (e.g. more ranks than files); do not
        # start a thread and an event loop for no work.
        return

    async def _prefetch_all() -> None:
        semaphore = asyncio.Semaphore(num_prefetch_threads)
        completed = 0
        next_log_pct = 10

        async def prefetch_one(path: str) -> None:
            nonlocal completed, next_log_pct
            try:
                async with semaphore:
                    # The blocking read runs on a worker thread; the
                    # semaphore caps how many run concurrently.
                    await asyncio.to_thread(_prefetch_checkpoint, path)
                # Only the event-loop thread touches these counters, so no
                # lock is needed.
                completed += 1
                pct = 100 * completed // total_for_rank
                if pct >= next_log_pct:
                    # Report the real percentage, not the threshold: with few
                    # files one completion can cross several 10% steps.
                    logger.info(
                        "Prefetching checkpoint files: %d%% (%d/%d)",
                        pct,
                        completed,
                        total_for_rank,
                    )
                    # Jump past every threshold already crossed.
                    next_log_pct = (pct // 10 + 1) * 10
            except Exception:
                # Prefetching is best-effort: the regular loader will still
                # read the file, just without a warm cache.
                logger.warning(
                    "Failed to prefetch checkpoint file %r.", path, exc_info=True
                )

        await asyncio.gather(*(prefetch_one(p) for p in paths_to_prefetch))

    def _run_prefetch() -> None:
        start = time.perf_counter()
        asyncio.run(_prefetch_all())
        elapsed = time.perf_counter() - start
        logger.info(
            "Prefetching checkpoint files into page cache finished in %.2fs",
            elapsed,
        )

    logger.info("Prefetching checkpoint files into page cache started (in background)")
    # Daemon thread: an unfinished prefetch must never block interpreter exit.
    threading.Thread(target=_run_prefetch, daemon=True).start()
def safetensors_weights_iterator( def safetensors_weights_iterator(
hf_weights_files: list[str], hf_weights_files: list[str],
use_tqdm_on_load: bool, use_tqdm_on_load: bool,
...@@ -736,9 +803,14 @@ def safetensors_weights_iterator( ...@@ -736,9 +803,14 @@ def safetensors_weights_iterator(
if safetensors_load_strategy == "eager": if safetensors_load_strategy == "eager":
loading_desc += " (eager)" loading_desc += " (eager)"
sorted_files = sorted(hf_weights_files, key=_natural_sort_key)
if safetensors_load_strategy == "prefetch":
_prefetch_all_checkpoints(sorted_files)
leftover_state_dict: dict[str, torch.Tensor] = {} leftover_state_dict: dict[str, torch.Tensor] = {}
for st_file in tqdm( for st_file in tqdm(
sorted(hf_weights_files, key=_natural_sort_key), sorted_files,
desc=loading_desc, desc=loading_desc,
disable=not enable_tqdm(use_tqdm_on_load), disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT, bar_format=_BAR_FORMAT,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment