Unverified Commit 8256833f authored by Simon Mo's avatar Simon Mo Committed by GitHub
Browse files

[Startup] Parallelize torch/transformers import + weight prefetch + forkserver prewarm (#40331)


Signed-off-by: simon-mo <simon@inferact.ai>
parent 80975912
......@@ -5,10 +5,80 @@
Note that all future modules must be lazily loaded within main
to avoid certain eager import breakage."""
import contextlib
import importlib.metadata
import os
import sys
import threading as _threading
from vllm.logger import init_logger
# [startup] Kick off torch + transformers .so/module loading in a background
# thread before we touch vllm.logger (which pulls vllm/__init__.py ->
# vllm.env_override -> `import torch` on the main thread). Python import
# lock serializes the same-module import across threads, but the .so dlopen
# inside torch's init releases the GIL during file I/O. Main thread's
# non-torch imports (vllm.envs submodules, stdlib, fastapi, etc.) can make
# progress on the CPU while the background thread pays the ~2 s of cuda
# .so loading. `import transformers` is also ~2 s of cold-disk work and
# depends on torch; chain it after torch in the same thread so subsequent
# `from transformers import ...` lines on the main thread hit a warm
# module cache.
def _bg_preload_torch() -> None:
    """Best-effort warm-up import of torch, followed by transformers.

    Meant to be used as a background-thread target. Any ``Exception``
    raised by either import is suppressed, and transformers is only
    attempted when the torch import succeeded.
    """
    # A single suppress block covers both imports: a torch failure leaves
    # the block before the transformers line is reached, and a transformers
    # failure is simply ignored.
    with contextlib.suppress(Exception):
        import torch  # noqa: F401
        import transformers  # noqa: F401
# Launch the torch/transformers warm-up as soon as this module is imported.
# The thread is a daemon, so it never keeps the interpreter alive at exit.
_threading.Thread(
    name="vllm-torch-preload",
    target=_bg_preload_torch,
    daemon=True,
).start()
# [startup] Pre-spawn EngineCore via forkserver preload, in a background
# thread. Only fires for `vllm serve` (the only subcommand that spawns a
# long-running EngineCore). The forkserver process is forked once and
# preloaded with vllm.v1.engine.async_llm (~3-5 s of imports). When
# AsyncLLM.from_vllm_config later runs, Process.start() forks from the
# already-warm forkserver instead of paying spawn() cost (~5 s in child
# for fresh Python + imports).
#
# Kicking the preload in a BG thread lets the ~3-5 s ensure_running cost
# overlap with APIServer's argparse + config resolution (~5-10 s on cold
# disk). Default cli_env_setup sets spawn; we override to forkserver
# before that runs so the path is consistent.
def _bg_prewarm_forkserver() -> None:
    """Select the forkserver start method and launch a preloaded forkserver.

    Meant to be used as a background-thread target. Best-effort: any
    ``Exception`` raised by the imports or by the multiprocessing calls is
    suppressed, in which case the remaining steps are skipped and the
    function returns normally.
    """
    with contextlib.suppress(Exception):
        import multiprocessing as mp
        from multiprocessing import forkserver as mp_forkserver

        # Order matters: the start method has to be chosen before the
        # forkserver is started. With force=False the call raises
        # RuntimeError when a start method has already been fixed for this
        # process; that is swallowed here like every other failure.
        mp.set_start_method("forkserver", force=False)
        # Modules the forkserver process imports once, up front, so that
        # processes forked from it inherit them already loaded.
        mp.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        mp_forkserver.ensure_running()
# Only the `serve` subcommand triggers the prewarm. The one-element slice
# is empty when no subcommand was given, so the comparison is then False.
if sys.argv[1:2] == ["serve"]:
    # Default the worker start method to forkserver unless the user has
    # already set the variable explicitly.
    os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "forkserver")
    # Daemon thread, so an early CLI exit (bad args, --help, import errors)
    # does not hang waiting for ensure_running(). The forkserver subprocess
    # is tracked by module-level state in multiprocessing.forkserver, so it
    # outlives this thread and later process starts reuse it.
    _threading.Thread(
        name="vllm-forkserver-prewarm",
        target=_bg_prewarm_forkserver,
        daemon=True,
    ).start()
from vllm.logger import init_logger # noqa: E402
logger = init_logger(__name__)
......
......@@ -12,7 +12,7 @@ import tempfile
import warnings
from argparse import Namespace
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, suppress
from typing import Any
import uvloop
......@@ -74,6 +74,121 @@ logger = init_logger("vllm.entrypoints.openai.api_server")
_FALLBACK_SUPPORTED_TASKS: tuple[SupportedTask, ...] = ("generate",)
def _startup_prefetch_weights(vllm_config: "VllmConfig") -> None:
    """Warm the OS page cache with the model's weight and sidecar files.

    Starts a daemon thread and returns immediately. The thread resolves the
    model directory (a local path, or the local HF / ModelScope cache
    snapshot), collects the weight shards plus the JSON / tokenizer
    sidecars, and reads every file once in 16 MB blocks so that a later
    reader of the same files finds them in the page cache.

    All work (directory resolution, cache lookup, globbing and the reads
    themselves) happens inside the background thread, so the caller is
    never blocked on I/O.

    Best-effort: an unresolvable model location or an unreadable file is
    logged at debug level and otherwise ignored; nothing is raised to the
    caller.

    Args:
        vllm_config: Engine configuration. Only ``model_config.model``,
            ``model_config.revision`` and ``load_config.download_dir`` are
            read.
    """
    import threading

    # Copy out the few scalar fields the worker needs so that its closure
    # does not keep the whole vllm_config object alive.
    model_ref = vllm_config.model_config.model
    revision = vllm_config.model_config.revision
    download_dir = vllm_config.load_config.download_dir

    def _prefetch_worker() -> None:
        import glob
        import os
        from concurrent.futures import ThreadPoolExecutor

        from vllm import envs

        candidate_dir: str | None = None

        # 1. Local path?
        if os.path.isdir(model_ref):
            candidate_dir = model_ref
        else:
            # 2. HF / ModelScope repo id: resolve to the local cache
            #    snapshot dir using the same revision / cache_dir the
            #    weight loader will use, so we prefetch the right files.
            #    local_files_only=True means nothing is downloaded here.
            try:
                if envs.VLLM_USE_MODELSCOPE:
                    from modelscope.hub.snapshot_download import (
                        snapshot_download,
                    )

                    candidate_dir = snapshot_download(
                        model_id=model_ref,
                        revision=revision,
                        cache_dir=download_dir,
                        local_files_only=True,
                    )
                else:
                    from huggingface_hub import snapshot_download

                    candidate_dir = snapshot_download(
                        repo_id=model_ref,
                        revision=revision,
                        cache_dir=download_dir,
                        allow_patterns=[
                            "*.safetensors",
                            "*.bin",
                            "*.json",
                            "*tokenizer*",
                        ],
                        local_files_only=True,
                    )
            except Exception:
                # Not cached yet, not a known repo id, or the hub client
                # is unavailable: nothing to prefetch.
                logger.debug(
                    "Parent-side weight prefetch skipped: could not "
                    "resolve %s in the local cache",
                    model_ref,
                )
                return

        if not candidate_dir or not os.path.isdir(candidate_dir):
            return

        # Escape the directory so glob metacharacters in the path itself
        # ("[", "*", "?") are matched literally.
        glob_root = glob.escape(candidate_dir)

        def _matches(*patterns: str) -> list[str]:
            """Return the sorted, de-duplicated matches for *patterns*."""
            found: set[str] = set()
            for pattern in patterns:
                found.update(glob.glob(os.path.join(glob_root, pattern)))
            return sorted(found)

        # Weight shards: large files, read into page cache.
        shard_paths = _matches("*.safetensors", "*.bin")
        # Tokenizer/config sidecars: small, but re-opened in the child and
        # add synchronous open+read latency when the disk is cold.
        # "*tokenizer*" already covers "tokenizer.model".
        sidecar_paths = _matches("*.json", "*tokenizer*")

        # The patterns overlap (e.g. tokenizer.json matches both "*.json"
        # and "*tokenizer*"), so drop repeats while keeping shards first,
        # and skip anything that is not a regular file (a directory whose
        # name contains "tokenizer" would otherwise fail on open()).
        prefetch_paths = [
            path
            for path in dict.fromkeys(shard_paths + sidecar_paths)
            if os.path.isfile(path)
        ]
        if not prefetch_paths:
            return

        logger.debug(
            "Parent-side weight prefetch starting for %d files in %s",
            len(prefetch_paths),
            candidate_dir,
        )

        # Match vLLM's in-child prefetch block size + thread count.
        block_size = 16 * 1024 * 1024  # 16 MB

        def read_one(p: str) -> None:
            """Read *p* to EOF and discard the data; never raises."""
            try:
                with open(p, "rb") as f:
                    while f.read(block_size):
                        pass
            except Exception:
                # Best-effort: one unreadable file must not abort the rest.
                logger.debug(
                    "Parent-side weight prefetch could not read %s",
                    p,
                    exc_info=True,
                )

        # Read files in parallel across 8 worker threads (bounded) to
        # saturate multi-spindle / multi-queue storage without thrashing.
        with ThreadPoolExecutor(max_workers=8) as pool:
            list(pool.map(read_one, prefetch_paths))

    threading.Thread(
        target=_prefetch_worker,
        daemon=True,
        name="vllm-parent-weight-prefetch",
    ).start()
@asynccontextmanager
async def build_async_engine_client(
args: Namespace,
......@@ -85,7 +200,10 @@ async def build_async_engine_client(
# The executor is expected to be mp.
# Pre-import heavy modules in the forkserver process
logger.debug("Setup forkserver with pre-imports")
multiprocessing.set_start_method("forkserver")
# May already have been set by the CLI entry's async prewarm
# (vllm/entrypoints/cli/main.py); tolerate re-call.
with suppress(RuntimeError):
multiprocessing.set_start_method("forkserver", force=False)
multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
forkserver.ensure_running()
logger.debug("Forkserver setup complete!")
......@@ -123,6 +241,28 @@ async def build_async_engine_client_from_engine_args(
# Create the EngineConfig (determines if we can use V1).
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# [startup] Start prefetching model weight shards into the OS page cache
# in a background thread from the PARENT APIServer process. EngineCore
# will page-fault on these same files ~10-15 s later (after fork + CUDA
# context + distributed init + model init). For large-weight cases
# (tens of GB) this parent-side head start meaningfully shrinks the
# prefetch+load phase that the engine's in-child prefetch otherwise
# barely overlaps.
#
# Skip in API-only workers that connect to an already-running EngineCore
# (multi-API-server / disaggregated setups): those processes never load
# weights, and if we prefetched from all of them we'd contend with the
# engine's own read. Presence of an `input_address` in client_config is
# the current marker that this worker is headless.
#
# Best-effort: if the model is a local path, glob for safetensors; if
# it's a repo-id, try to resolve via HF hub (or ModelScope) local cache.
# Any failure silently falls through to the existing in-child prefetch
# path. All I/O (incl. directory resolution) runs inside the BG thread
# so the asyncio event loop is never blocked.
if not (client_config and client_config.get("input_address")):
_startup_prefetch_weights(vllm_config)
from vllm.v1.engine.async_llm import AsyncLLM
async_llm: AsyncLLM | None = None
......
......@@ -62,7 +62,7 @@ if TYPE_CHECKING:
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
VLLM_USE_RAY_V2_EXECUTOR_BACKEND: bool = False
VLLM_XLA_USE_SPMD: bool = False
VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork"
VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn", "forkserver"] = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_ASSETS_CACHE_MODEL_CLEAN: bool = False
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
......@@ -765,9 +765,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
int(os.getenv("VLLM_USE_RAY_V2_EXECUTOR_BACKEND", "0"))
),
# Use dedicated multiprocess context for workers.
# Both spawn and fork work
# spawn, fork, and forkserver all work. forkserver is opt-in for fast
# startup when paired with the CLI's async prewarm (see
# vllm/entrypoints/cli/main.py) — the forkserver process is preloaded
# with vllm.v1.engine.async_llm and a subsequent EngineCore Process.start()
# forks from that warm sibling instead of paying spawn cost.
"VLLM_WORKER_MULTIPROC_METHOD": env_with_choices(
"VLLM_WORKER_MULTIPROC_METHOD", "fork", ["spawn", "fork"]
"VLLM_WORKER_MULTIPROC_METHOD", "fork", ["spawn", "fork", "forkserver"]
),
# Path to the cache for storing downloaded assets
"VLLM_ASSETS_CACHE": lambda: os.path.expanduser(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment