"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "262d263f6c56fa95e15422d3a475da8efdf67cc1"
Unverified Commit 8256833f authored by Simon Mo's avatar Simon Mo Committed by GitHub
Browse files

[Startup] Parallelize torch/transformers import + weight prefetch + forkserver prewarm (#40331)


Signed-off-by: default avatarsimon-mo <simon@inferact.ai>
parent 80975912
...@@ -5,10 +5,80 @@ ...@@ -5,10 +5,80 @@
Note that all future modules must be lazily loaded within main Note that all future modules must be lazily loaded within main
to avoid certain eager import breakage.""" to avoid certain eager import breakage."""
import contextlib
import importlib.metadata import importlib.metadata
import os
import sys import sys
import threading as _threading
from vllm.logger import init_logger
# [startup] Kick off torch + transformers .so/module loading in a background
# thread before we touch vllm.logger (which pulls vllm/__init__.py ->
# vllm.env_override -> `import torch` on the main thread). Python import
# lock serializes the same-module import across threads, but the .so dlopen
# inside torch's init releases the GIL during file I/O. Main thread's
# non-torch imports (vllm.envs submodules, stdlib, fastapi, etc.) can make
# progress on the CPU while the background thread pays the ~2 s of cuda
# .so loading. `import transformers` is also ~2 s of cold-disk work and
# depends on torch; chain it after torch in the same thread so subsequent
# `from transformers import ...` lines on the main thread hit a warm
# module cache.
def _bg_preload_torch() -> None:
    """Import ``torch`` and then ``transformers``, ignoring any failure.

    Intended as a background-thread target: it returns ``None`` and never
    raises an ``Exception``. ``transformers`` is only attempted when the
    ``torch`` import succeeded, because both imports share one suppress
    block and the first failure leaves it.
    """
    with contextlib.suppress(Exception):
        import torch  # noqa: F401
        import transformers  # noqa: F401
# Fire-and-forget: the preload runs on a daemon thread that is started here
# and never joined, so it cannot keep the interpreter alive on exit.
_threading.Thread(
target=_bg_preload_torch, daemon=True, name="vllm-torch-preload"
).start()
# [startup] Pre-spawn EngineCore via forkserver preload, in a background
# thread. Only fires for `vllm serve` (the only subcommand that spawns a
# long-running EngineCore). The forkserver process is forked once and
# preloaded with vllm.v1.engine.async_llm (~3-5 s of imports). When
# AsyncLLM.from_vllm_config later runs, Process.start() forks from the
# already-warm forkserver instead of paying spawn() cost (~5 s in child
# for fresh Python + imports).
#
# Kicking the preload in a BG thread lets the ~3-5 s ensure_running cost
# overlap with APIServer's argparse + config resolution (~5-10 s on cold
# disk). Default cli_env_setup sets spawn; we override to forkserver
# before that runs so the path is consistent.
def _bg_prewarm_forkserver() -> None:
    """Select the forkserver start method and start a preloaded forkserver.

    Best-effort background-thread target: any ``Exception`` raised by the
    multiprocessing calls below (for example because a start method has
    already been fixed for this process) is suppressed and the function
    simply returns ``None``.
    """
    with contextlib.suppress(Exception):
        import multiprocessing as mp
        from multiprocessing import forkserver as fs

        # The start method has to be chosen before the forkserver is
        # started. force=False means an already-fixed start method is not
        # overridden; the resulting error is swallowed by the suppress.
        mp.set_start_method("forkserver", force=False)
        mp.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        fs.ensure_running()
if len(sys.argv) > 1 and sys.argv[1] == "serve":
    # Default `vllm serve` to forkserver, but keep an explicit user choice:
    # setdefault() only writes when the variable is unset and returns the
    # value that is effectively in the environment afterwards.
    #
    # Only prewarm when forkserver is the effective method. The prewarm
    # calls multiprocessing.set_start_method("forkserver"), which is
    # process-wide, and spawns a forkserver process; doing that when the
    # user explicitly asked for "spawn" or "fork" would contradict their
    # setting and waste a process.
    if (
        os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "forkserver").lower()
        == "forkserver"
    ):
        # daemon=True so early CLI exits (bad args, --help, import errors)
        # don't hang waiting for ensure_running(). The forkserver subprocess
        # itself is tracked by module-level state in
        # multiprocessing.forkserver and survives this thread exiting;
        # subsequent spawn() calls reuse it.
        _threading.Thread(
            target=_bg_prewarm_forkserver,
            daemon=True,
            name="vllm-forkserver-prewarm",
        ).start()
from vllm.logger import init_logger # noqa: E402
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -12,7 +12,7 @@ import tempfile ...@@ -12,7 +12,7 @@ import tempfile
import warnings import warnings
from argparse import Namespace from argparse import Namespace
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager, suppress
from typing import Any from typing import Any
import uvloop import uvloop
...@@ -74,6 +74,121 @@ logger = init_logger("vllm.entrypoints.openai.api_server") ...@@ -74,6 +74,121 @@ logger = init_logger("vllm.entrypoints.openai.api_server")
_FALLBACK_SUPPORTED_TASKS: tuple[SupportedTask, ...] = ("generate",) _FALLBACK_SUPPORTED_TASKS: tuple[SupportedTask, ...] = ("generate",)
def _startup_prefetch_weights(vllm_config: "VllmConfig") -> None:
    """Start a background read of the model's files into the OS page cache.

    Called from the parent APIServer so that the files EngineCore reads
    later from the child process are already cached by the kernel.

    All work (directory resolution, HF/ModelScope cache lookup, globbing,
    and the reads themselves) runs inside a daemon thread, so this function
    returns immediately and does not block the asyncio event loop.

    Best-effort: an unknown model location or uncached repo makes the
    worker return early, and per-file read errors are logged at debug level
    and skipped. vLLM's existing in-child prefetch then runs normally.

    Args:
        vllm_config: Engine configuration. Only ``model_config.model``,
            ``model_config.revision`` and ``load_config.download_dir`` are
            read, and they are captured before the thread starts.
    """
    import threading

    # Capture only the small scalar fields the thread needs. Avoid holding
    # a reference to vllm_config (which contains unpicklable objects) for
    # longer than necessary.
    model_ref = vllm_config.model_config.model
    revision = vllm_config.model_config.revision
    download_dir = vllm_config.load_config.download_dir

    def _prefetch_worker() -> None:
        import glob
        import os
        from concurrent.futures import ThreadPoolExecutor

        from vllm import envs

        candidate_dir: str | None = None

        # 1. Local path?
        if os.path.isdir(model_ref):
            candidate_dir = model_ref
        else:
            # 2. HF / ModelScope repo id — resolve to the local cache
            #    snapshot dir using the same revision / cache_dir the weight
            #    loader will use, so we prefetch the right files.
            try:
                if envs.VLLM_USE_MODELSCOPE:
                    from modelscope.hub.snapshot_download import (
                        snapshot_download,
                    )

                    candidate_dir = snapshot_download(
                        model_id=model_ref,
                        revision=revision,
                        cache_dir=download_dir,
                        local_files_only=True,
                    )
                else:
                    from huggingface_hub import snapshot_download

                    candidate_dir = snapshot_download(
                        repo_id=model_ref,
                        revision=revision,
                        cache_dir=download_dir,
                        allow_patterns=[
                            "*.safetensors",
                            "*.bin",
                            "*.json",
                            "*tokenizer*",
                        ],
                        local_files_only=True,
                    )
            except Exception:
                return  # not cached yet or not a known repo id

        if not candidate_dir or not os.path.isdir(candidate_dir):
            return

        def _glob_sorted(*patterns: str) -> list[str]:
            """Return the sorted matches of *patterns* inside candidate_dir."""
            return sorted(
                match
                for pattern in patterns
                for match in glob.glob(os.path.join(candidate_dir, pattern))
            )

        # Weight shards: large files, read into page cache.
        shard_paths = _glob_sorted("*.safetensors", "*.bin")
        # Tokenizer/config sidecars: small, but re-opened in the child and
        # add synchronous open+read latency when the disk is cold.
        sidecar_paths = _glob_sorted("*.json", "tokenizer.model", "*tokenizer*")

        # The patterns overlap (e.g. tokenizer.json matches both "*.json"
        # and "*tokenizer*"), so de-duplicate while keeping shards first;
        # otherwise the same file is queued and read more than once. Also
        # drop anything that is not a regular file (a glob can match a
        # directory, which cannot be read).
        prefetch_paths = [
            path
            for path in dict.fromkeys(shard_paths + sidecar_paths)
            if os.path.isfile(path)
        ]
        if not prefetch_paths:
            return

        logger.debug(
            "Parent-side weight prefetch starting for %d files in %s",
            len(prefetch_paths),
            candidate_dir,
        )

        # Match vLLM's in-child prefetch block size + thread count.
        block_size = 16 * 1024 * 1024  # 16 MB

        def read_one(p: str) -> None:
            """Read file *p* to EOF, discarding the data."""
            try:
                with open(p, "rb") as f:
                    while f.read(block_size):
                        pass
            except Exception:
                # Stay best-effort, but leave a trace for debugging instead
                # of swallowing the failure silently.
                logger.debug(
                    "Parent-side weight prefetch skipped %s", p, exc_info=True
                )

        # Read files in parallel across 8 worker threads (bounded) to
        # saturate multi-spindle / multi-queue storage without thrashing.
        # NOTE(review): ThreadPoolExecutor workers are joined at interpreter
        # exit, so an early shutdown may wait for in-flight file reads to
        # finish — confirm this is acceptable for very large shards.
        with ThreadPoolExecutor(max_workers=8) as pool:
            list(pool.map(read_one, prefetch_paths))

    threading.Thread(
        target=_prefetch_worker,
        daemon=True,
        name="vllm-parent-weight-prefetch",
    ).start()
@asynccontextmanager @asynccontextmanager
async def build_async_engine_client( async def build_async_engine_client(
args: Namespace, args: Namespace,
...@@ -85,7 +200,10 @@ async def build_async_engine_client( ...@@ -85,7 +200,10 @@ async def build_async_engine_client(
# The executor is expected to be mp. # The executor is expected to be mp.
# Pre-import heavy modules in the forkserver process # Pre-import heavy modules in the forkserver process
logger.debug("Setup forkserver with pre-imports") logger.debug("Setup forkserver with pre-imports")
multiprocessing.set_start_method("forkserver") # May already have been set by the CLI entry's async prewarm
# (vllm/entrypoints/cli/main.py); tolerate re-call.
with suppress(RuntimeError):
multiprocessing.set_start_method("forkserver", force=False)
multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"]) multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
forkserver.ensure_running() forkserver.ensure_running()
logger.debug("Forkserver setup complete!") logger.debug("Forkserver setup complete!")
...@@ -123,6 +241,28 @@ async def build_async_engine_client_from_engine_args( ...@@ -123,6 +241,28 @@ async def build_async_engine_client_from_engine_args(
# Create the EngineConfig (determines if we can use V1). # Create the EngineConfig (determines if we can use V1).
vllm_config = engine_args.create_engine_config(usage_context=usage_context) vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# [startup] Start prefetching model weight shards into the OS page cache
# in a background thread from the PARENT APIServer process. EngineCore
# will page-fault on these same files ~10-15 s later (after fork + CUDA
# context + distributed init + model init). For large-weight cases
# (tens of GB) this parent-side head start meaningfully shrinks the
# prefetch+load phase that the engine's in-child prefetch otherwise
# barely overlaps.
#
# Skip in API-only workers that connect to an already-running EngineCore
# (multi-API-server / disaggregated setups): those processes never load
# weights, and if we prefetched from all of them we'd contend with the
# engine's own read. Presence of an `input_address` in client_config is
# the current marker that this worker is headless.
#
# Best-effort: if the model is a local path, glob for safetensors; if
# it's a repo-id, try to resolve via HF hub (or ModelScope) local cache.
# Any failure silently falls through to the existing in-child prefetch
# path. All I/O (incl. directory resolution) runs inside the BG thread
# so the asyncio event loop is never blocked.
if not (client_config and client_config.get("input_address")):
_startup_prefetch_weights(vllm_config)
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
async_llm: AsyncLLM | None = None async_llm: AsyncLLM | None = None
......
...@@ -62,7 +62,7 @@ if TYPE_CHECKING: ...@@ -62,7 +62,7 @@ if TYPE_CHECKING:
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
VLLM_USE_RAY_V2_EXECUTOR_BACKEND: bool = False VLLM_USE_RAY_V2_EXECUTOR_BACKEND: bool = False
VLLM_XLA_USE_SPMD: bool = False VLLM_XLA_USE_SPMD: bool = False
VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork" VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn", "forkserver"] = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_ASSETS_CACHE_MODEL_CLEAN: bool = False VLLM_ASSETS_CACHE_MODEL_CLEAN: bool = False
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
...@@ -765,9 +765,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -765,9 +765,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
int(os.getenv("VLLM_USE_RAY_V2_EXECUTOR_BACKEND", "0")) int(os.getenv("VLLM_USE_RAY_V2_EXECUTOR_BACKEND", "0"))
), ),
# Use dedicated multiprocess context for workers. # Use dedicated multiprocess context for workers.
# Both spawn and fork work # spawn, fork, and forkserver all work. forkserver is opt-in for fast
# startup when paired with the CLI's async prewarm (see
# vllm/entrypoints/cli/main.py) — the forkserver process is preloaded
# with vllm.v1.engine.async_llm and a subsequent EngineCore Process.start()
# forks from that warm sibling instead of paying spawn cost.
"VLLM_WORKER_MULTIPROC_METHOD": env_with_choices( "VLLM_WORKER_MULTIPROC_METHOD": env_with_choices(
"VLLM_WORKER_MULTIPROC_METHOD", "fork", ["spawn", "fork"] "VLLM_WORKER_MULTIPROC_METHOD", "fork", ["spawn", "fork", "forkserver"]
), ),
# Path to the cache for storing downloaded assets # Path to the cache for storing downloaded assets
"VLLM_ASSETS_CACHE": lambda: os.path.expanduser( "VLLM_ASSETS_CACHE": lambda: os.path.expanduser(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment