Unverified Commit ad6c655d authored by Lionel Villard's avatar Lionel Villard Committed by GitHub
Browse files

preload heavy modules when mp method is forkserver (#22214)


Signed-off-by: default avatarLionel Villard <villard@us.ibm.com>
parent 14bcf93a
...@@ -13,7 +13,6 @@ import numpy as np ...@@ -13,7 +13,6 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
import vllm.envs as envs import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format,
write_to_json) write_to_json)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
...@@ -85,6 +84,9 @@ def main(args: argparse.Namespace): ...@@ -85,6 +84,9 @@ def main(args: argparse.Namespace):
"Please set it to a valid path to use torch profiler.") "Please set it to a valid path to use torch profiler.")
engine_args = EngineArgs.from_cli_args(args) engine_args = EngineArgs.from_cli_args(args)
# Lazy import to avoid importing LLM when the bench command is not selected.
from vllm import LLM, SamplingParams
# NOTE(woosuk): If the request cannot be processed in a single batch, # NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches. # the engine will automatically process the request in multiple batches.
llm = LLM(**dataclasses.asdict(engine_args)) llm = LLM(**dataclasses.asdict(engine_args))
......
...@@ -8,6 +8,7 @@ import importlib ...@@ -8,6 +8,7 @@ import importlib
import inspect import inspect
import json import json
import multiprocessing import multiprocessing
import multiprocessing.forkserver as forkserver
import os import os
import signal import signal
import socket import socket
...@@ -155,6 +156,15 @@ async def build_async_engine_client( ...@@ -155,6 +156,15 @@ async def build_async_engine_client(
client_config: Optional[dict[str, Any]] = None, client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]: ) -> AsyncIterator[EngineClient]:
if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver":
# The executor is expected to be mp.
# Pre-import heavy modules in the forkserver process
logger.debug("Setup forkserver with pre-imports")
multiprocessing.set_start_method('forkserver')
multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
forkserver.ensure_running()
logger.debug("Forkserver setup complete!")
# Context manager to handle engine_client lifecycle # Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit # Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment