Unverified Commit 657f2f30 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[DP] Support external DP Load Balancer mode (#19790)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent a1aafc82
...@@ -155,6 +155,7 @@ steps: ...@@ -155,6 +155,7 @@ steps:
- examples/offline_inference/rlhf_colocate.py - examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py - tests/examples/offline_inference/data_parallel.py
- tests/v1/test_async_llm_dp.py - tests/v1/test_async_llm_dp.py
- tests/v1/test_external_lb_dp.py
- tests/v1/engine/test_engine_core_client.py - tests/v1/engine/test_engine_core_client.py
commands: commands:
# test with tp=2 and external_dp=2 # test with tp=2 and external_dp=2
...@@ -163,8 +164,9 @@ steps: ...@@ -163,8 +164,9 @@ steps:
# test with tp=2 and pp=2 # test with tp=2 and pp=2
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp # test with internal dp
- python3 ../examples/offline_inference/data_parallel.py - python3 ../examples/offline_inference/data_parallel.py --enforce-eager
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
- pytest -v -s distributed/test_utils.py - pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/test_basic_correctness.py
...@@ -682,10 +684,12 @@ steps: ...@@ -682,10 +684,12 @@ steps:
- vllm/worker/model_runner.py - vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py - entrypoints/llm/test_collective_rpc.py
- tests/v1/test_async_llm_dp.py - tests/v1/test_async_llm_dp.py
- tests/v1/test_external_lb_dp.py
- tests/v1/entrypoints/openai/test_multi_api_servers.py - tests/v1/entrypoints/openai/test_multi_api_servers.py
- vllm/v1/engine/ - vllm/v1/engine/
commands: commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_basic_correctness.py
......
...@@ -26,8 +26,8 @@ from vllm.v1.engine import EngineCoreRequest ...@@ -26,8 +26,8 @@ from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
SyncMPClient) SyncMPClient)
from vllm.v1.engine.utils import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import CoreEngineProcManager
from ...distributed.conftest import MockSubscriber from ...distributed.conftest import MockSubscriber
from ...utils import create_new_process_for_each_test from ...utils import create_new_process_for_each_test
...@@ -563,7 +563,7 @@ def test_engine_core_proc_instantiation_cuda_empty( ...@@ -563,7 +563,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices
from vllm.v1.utils import EngineZmqAddresses from vllm.v1.engine.utils import EngineZmqAddresses
def mock_startup_handshake(self, handshake_socket, on_head_node, def mock_startup_handshake(self, handshake_socket, on_head_node,
parallel_config): parallel_config):
...@@ -580,7 +580,7 @@ def test_engine_core_proc_instantiation_cuda_empty( ...@@ -580,7 +580,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
trust_remote_code=True).create_engine_config() trust_remote_code=True).create_engine_config()
engine_core_proc = EngineCoreProc( engine_core_proc = EngineCoreProc(
vllm_config=vllm_config, vllm_config=vllm_config,
on_head_node=True, local_client=True,
handshake_address="tcp://127.0.0.1:12345", handshake_address="tcp://127.0.0.1:12345",
executor_class=mock_executor_class, executor_class=mock_executor_class,
log_stats=False, log_stats=False,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import threading
import time
from contextlib import AsyncExitStack
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from vllm.platforms import Platform
# Model served by every server instance started in this module.
MODEL_NAME = "ibm-research/PowerMoE-3b"
# Number of data parallel ranks for external LB testing
# (overridable through the DP_SIZE environment variable, default 2).
DP_SIZE = int(os.getenv("DP_SIZE", "2"))
# Default tensor parallel size to use
# (overridable through the TP_SIZE environment variable, default 1).
TP_SIZE = int(os.getenv("TP_SIZE", "1"))
class ExternalLBServerManager:
    """Manages data parallel vLLM server instances for external
    load balancer testing.

    Intended to be used as a context manager: entering launches ``dp_size``
    servers (one per data parallel rank, each on its own port and with its
    own ``CUDA_VISIBLE_DEVICES`` value) and returns them as
    ``(server, server_args)`` tuples ordered by rank; exiting stops them.
    """

    def __init__(self,
                 model_name: str,
                 dp_size: int,
                 api_server_count: int,
                 base_server_args: list,
                 tp_size: int = TP_SIZE):
        """Store the launch configuration; no server is started here.

        Args:
            model_name: Model passed to every ``RemoteOpenAIServer``.
            dp_size: Number of data parallel ranks (one server per rank).
            api_server_count: Value passed as ``--api-server-count``.
            base_server_args: CLI args shared by all ranks (copied per rank).
            tp_size: Value passed as ``--tensor-parallel-size`` and used to
                size each rank's device slice.
        """
        self.model_name = model_name
        self.dp_size = dp_size
        self.tp_size = tp_size
        self.api_server_count = api_server_count
        self.base_server_args = base_server_args
        self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
        self.server_threads: list[threading.Thread] = []

    def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
        """Start all server instances for external LB mode.

        Returns:
            The started servers with their argument lists, ordered by rank.

        Raises:
            RuntimeError: If any rank failed to start. Servers that did
                start are stopped before raising.
        """
        # One slot per rank so the result is rank-ordered regardless of the
        # order in which the startup threads finish.
        started: list = [None] * self.dp_size

        for rank in range(self.dp_size):
            # Create server args for this specific rank
            server_args = self.base_server_args.copy()

            # Add external LB specific arguments
            server_args.extend([
                "--data-parallel-size",
                str(self.dp_size),
                "--data-parallel-rank",
                str(rank),
                "--data-parallel-size-local",
                "1",
                "--tensor-parallel-size",
                str(self.tp_size),
                "--port",
                str(8000 + rank),  # Different port for each rank
                "--api-server-count",
                str(self.api_server_count),
            ])

            # Use a thread to start each server to allow parallel initialization
            def start_server(r: int, sargs: list[str]):
                try:
                    # Give rank r the device slice [r * tp, (r + 1) * tp).
                    # This must use the instance's tp_size so it agrees with
                    # the --tensor-parallel-size flag passed above.
                    visible_devices = ",".join(
                        str(Platform.device_id_to_physical_device_id(i))
                        for i in range(r * self.tp_size,
                                       (r + 1) * self.tp_size))
                    # Start the server
                    server = RemoteOpenAIServer(
                        self.model_name,
                        sargs,
                        auto_port=False,
                        env_dict={"CUDA_VISIBLE_DEVICES": visible_devices})
                    server.__enter__()
                    print(f"Server rank {r} started successfully with "
                          f"{self.api_server_count} API servers")
                    started[r] = (server, sargs)
                except Exception as e:
                    print(f"Failed to start server rank {r}: {e}")
                    raise

            thread = threading.Thread(target=start_server,
                                      args=(rank, server_args))
            thread.start()

            self.server_threads.append(thread)

        # Wait for all servers to start
        for thread in self.server_threads:
            thread.join()

        self.servers = [entry for entry in started if entry is not None]

        # Give servers additional time to fully initialize and coordinate
        time.sleep(2)

        if len(self.servers) != self.dp_size:
            # __exit__ is not invoked when __enter__ raises, so stop the
            # servers that did come up to avoid leaking their processes.
            self.__exit__(None, None, None)
            raise RuntimeError("Servers failed to start")

        return self.servers

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop all server instances.

        Shutdown is best-effort: an error while stopping one server is
        printed and the remaining servers are still stopped.
        """
        while self.servers:
            try:
                self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
            except Exception as e:
                print(f"Error stopping server: {e}")
@pytest.fixture(scope="module")
def default_server_args():
    """Return the CLI flags shared by every server launched in this module."""
    # use half precision for speed and memory savings in CI environment
    precision_flags = ["--dtype", "bfloat16"]
    limit_flags = ["--max-model-len", "2048", "--max-num-seqs", "128"]
    return [*precision_flags, *limit_flags, "--enforce-eager"]
@pytest.fixture(scope="module", params=[1, 4])
def servers(request, default_server_args):
    """Yield the running servers, once per parametrized API server count.

    The servers are stopped when the ``with`` block exits after the yield.
    """
    manager = ExternalLBServerManager(MODEL_NAME, DP_SIZE, request.param,
                                      default_server_args)
    with manager as running_servers:
        yield running_servers
@pytest_asyncio.fixture
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
    """Yield one async client per server, in the same order as ``servers``.

    Every client is entered on a shared ``AsyncExitStack`` so that all of
    them are closed when the stack exits after the yield.
    """
    async with AsyncExitStack() as stack:
        opened_clients = []
        for server, _ in servers:
            client_ctx = server.get_async_client()
            opened_clients.append(await stack.enter_async_context(client_ctx))
        yield opened_clients
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_external_lb_single_completion(clients: list[
    openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
                                             model_name: str) -> None:
    """Check non-streaming completions against every external-LB server.

    Sends one request to each server sequentially, then two concurrent
    bursts of requests spread over all servers, validating every response.
    """

    async def make_request(client: openai.AsyncOpenAI):
        """Issue one completion request and validate the response."""
        completion = await client.completions.create(
            model=model_name,
            prompt="Hello, my name is",
            max_tokens=10,
            temperature=1.0)

        assert completion.id is not None
        assert completion.choices is not None and len(completion.choices) == 1

        choice = completion.choices[0]
        # The exact number of tokens can vary slightly with temperature=1.0,
        # so we check for a reasonable minimum length.
        assert len(choice.text) >= 1
        # Finish reason might not always be 'length' if the model finishes early
        # or due to other reasons, especially with high temperature.
        # So, we'll accept 'length' or 'stop'.
        assert choice.finish_reason in ("length", "stop")

        # Token counts can also vary, so we check they are positive.
        assert completion.usage.completion_tokens > 0
        assert completion.usage.prompt_tokens > 0
        assert completion.usage.total_tokens > 0
        return completion

    # 25 requests per server (50 in total with the default DP_SIZE of 2).
    num_requests_per_server = 25

    async def run_burst() -> None:
        """Send a concurrent batch of requests to every server."""
        results = await asyncio.gather(
            *(make_request(client) for client in clients
              for _ in range(num_requests_per_server)))
        assert len(results) == num_requests_per_server * len(clients)
        assert all(completion is not None for completion in results)

    # Test single request to each server
    for i, client in enumerate(clients):
        result = await make_request(client)
        assert result is not None
        print(f"Server {i} handled single completion request successfully")

    await asyncio.sleep(0.5)

    # Send requests to all servers concurrently
    await run_burst()

    await asyncio.sleep(0.5)

    # Second burst of requests
    await run_burst()

    # Report the API server count the servers were launched with, falling
    # back to 1 when the flag is absent from the argument list.
    _, server_args = servers[0]
    if "--api-server-count" in server_args:
        api_server_count = server_args[
            server_args.index("--api-server-count") + 1]
    else:
        api_server_count = 1
    print(
        f"Successfully completed external LB test with {len(clients)} servers "
        f"(API server count: {api_server_count})")
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_external_lb_completion_streaming(clients: list[
    openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
                                                model_name: str) -> None:
    """Check streaming completions against every external-LB server.

    For each request the streamed text is compared with a non-streamed
    completion of the same prompt. One request is sent to each server
    sequentially, followed by two concurrent bursts over all servers.
    """
    prompt = "What is an LLM?"

    async def make_streaming_request(client: openai.AsyncOpenAI):
        """Compare a streamed completion with its non-streamed counterpart."""
        # Perform a non-streaming request to get the expected full output
        single_completion = await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
        )
        single_output = single_completion.choices[0].text

        # Perform the streaming request
        stream = await client.completions.create(model=model_name,
                                                 prompt=prompt,
                                                 max_tokens=5,
                                                 temperature=0.0,
                                                 stream=True)
        chunks: list[str] = []
        finish_reason_count = 0
        last_chunk = None
        async for chunk in stream:
            chunks.append(chunk.choices[0].text)
            if chunk.choices[0].finish_reason is not None:
                finish_reason_count += 1
            last_chunk = chunk  # Keep track of the last chunk

        # finish reason should only return in the last block for OpenAI API
        assert finish_reason_count == 1, (
            "Finish reason should appear exactly once.")
        assert last_chunk is not None, (
            "Stream should have yielded at least one chunk.")
        assert last_chunk.choices[
            0].finish_reason == "length", "Finish reason should be 'length'."
        # Check that the combined text matches the non-streamed version.
        assert "".join(
            chunks
        ) == single_output, "Streamed output should match non-streamed output."
        return True  # Indicate success for this request

    # 25 requests per server (50 in total with the default DP_SIZE of 2).
    num_requests_per_server = 25

    async def run_burst() -> None:
        """Send a concurrent batch of streaming requests to every server."""
        results = await asyncio.gather(
            *(make_streaming_request(client) for client in clients
              for _ in range(num_requests_per_server)))
        assert len(results) == num_requests_per_server * len(clients)
        assert all(
            results), "Not all streaming requests completed successfully."

    # Test single request to each server
    for i, client in enumerate(clients):
        result = await make_streaming_request(client)
        assert result is not None
        print(f"Server {i} handled single streaming request successfully")

    await asyncio.sleep(0.5)

    # Send streaming requests to all servers concurrently
    await run_burst()

    await asyncio.sleep(0.5)

    # Second burst of streaming requests
    await run_burst()

    # Report the API server count the servers were launched with, falling
    # back to 1 when the flag is absent from the argument list.
    _, server_args = servers[0]
    if "--api-server-count" in server_args:
        api_server_count = server_args[
            server_args.index("--api-server-count") + 1]
    else:
        api_server_count = 1
    print(f"Successfully completed external LB streaming test with "
          f"{len(clients)} servers (API server count: {api_server_count})")
...@@ -1784,6 +1784,10 @@ class ParallelConfig: ...@@ -1784,6 +1784,10 @@ class ParallelConfig:
"""Port of the data parallel master.""" """Port of the data parallel master."""
data_parallel_backend: str = "mp" data_parallel_backend: str = "mp"
"""Backend to use for data parallel, either "mp" or "ray".""" """Backend to use for data parallel, either "mp" or "ray"."""
data_parallel_external_lb: bool = False
"""Whether to use "external" DP LB mode. Applies only to online serving
and when data_parallel_size > 1. Set implicitly when
data_parallel_rank is provided explicitly to vllm serve."""
enable_expert_parallel: bool = False enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers.""" """Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False enable_eplb: bool = False
...@@ -1953,6 +1957,11 @@ class ParallelConfig: ...@@ -1953,6 +1957,11 @@ class ParallelConfig:
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args. # Data parallel was specified in the engine args.
self.data_parallel_master_port = get_open_port() self.data_parallel_master_port = get_open_port()
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
raise ValueError(
f"data_parallel_rank ({self.data_parallel_rank})"
f" must be in the range [0, {self.data_parallel_size})")
else: else:
# Otherwise fall back to env vars (e.g. for offline SPMD case). # Otherwise fall back to env vars (e.g. for offline SPMD case).
self.data_parallel_size = envs.VLLM_DP_SIZE self.data_parallel_size = envs.VLLM_DP_SIZE
...@@ -1961,6 +1970,10 @@ class ParallelConfig: ...@@ -1961,6 +1970,10 @@ class ParallelConfig:
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
if self.data_parallel_external_lb:
raise ValueError("data_parallel_external_lb can only "
"be set when data_parallel_size > 1")
if self.distributed_executor_backend == "external_launcher": if self.distributed_executor_backend == "external_launcher":
import os import os
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
......
...@@ -318,6 +318,7 @@ class EngineArgs: ...@@ -318,6 +318,7 @@ class EngineArgs:
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: Optional[int] = None
data_parallel_size_local: Optional[int] = None data_parallel_size_local: Optional[int] = None
data_parallel_address: Optional[str] = None data_parallel_address: Optional[str] = None
data_parallel_rpc_port: Optional[int] = None data_parallel_rpc_port: Optional[int] = None
...@@ -655,6 +656,12 @@ class EngineArgs: ...@@ -655,6 +656,12 @@ class EngineArgs:
**parallel_kwargs["tensor_parallel_size"]) **parallel_kwargs["tensor_parallel_size"])
parallel_group.add_argument("--data-parallel-size", "-dp", parallel_group.add_argument("--data-parallel-size", "-dp",
**parallel_kwargs["data_parallel_size"]) **parallel_kwargs["data_parallel_size"])
parallel_group.add_argument(
'--data-parallel-rank',
'-dpn',
type=int,
help='Data parallel rank of this instance. '
'When set, enables external load balancer mode.')
parallel_group.add_argument('--data-parallel-size-local', parallel_group.add_argument('--data-parallel-size-local',
'-dpl', '-dpl',
type=int, type=int,
...@@ -1126,10 +1133,17 @@ class EngineArgs: ...@@ -1126,10 +1133,17 @@ class EngineArgs:
# but we should not do this here. # but we should not do this here.
placement_group = ray.util.get_current_placement_group() placement_group = ray.util.get_current_placement_group()
# Local DP size defaults to global DP size if not set. data_parallel_external_lb = self.data_parallel_rank is not None
data_parallel_size_local = self.data_parallel_size if ( if data_parallel_external_lb:
self.data_parallel_size_local assert self.data_parallel_size_local in (1, None), (
is None) else self.data_parallel_size_local "data_parallel_size_local must be 1 when data_parallel_rank "
"is set")
data_parallel_size_local = 1
elif self.data_parallel_size_local is not None:
data_parallel_size_local = self.data_parallel_size_local
else:
# Local DP size defaults to global DP size if not set.
data_parallel_size_local = self.data_parallel_size
# DP address, used in multi-node case for torch distributed group # DP address, used in multi-node case for torch distributed group
# and ZMQ sockets. # and ZMQ sockets.
...@@ -1154,16 +1168,16 @@ class EngineArgs: ...@@ -1154,16 +1168,16 @@ class EngineArgs:
self.data_parallel_rpc_port self.data_parallel_rpc_port
is not None) else ParallelConfig.data_parallel_rpc_port is not None) else ParallelConfig.data_parallel_rpc_port
data_parallel_backend = self.data_parallel_backend
parallel_config = ParallelConfig( parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size, pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size, tensor_parallel_size=self.tensor_parallel_size,
data_parallel_size=self.data_parallel_size, data_parallel_size=self.data_parallel_size,
data_parallel_rank=self.data_parallel_rank or 0,
data_parallel_external_lb=data_parallel_external_lb,
data_parallel_size_local=data_parallel_size_local, data_parallel_size_local=data_parallel_size_local,
data_parallel_master_ip=data_parallel_address, data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port, data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=data_parallel_backend, data_parallel_backend=self.data_parallel_backend,
enable_expert_parallel=self.enable_expert_parallel, enable_expert_parallel=self.enable_expert_parallel,
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.num_redundant_experts, num_redundant_experts=self.num_redundant_experts,
......
...@@ -5,9 +5,9 @@ import argparse ...@@ -5,9 +5,9 @@ import argparse
import os import os
import signal import signal
import sys import sys
from typing import Optional
import uvloop import uvloop
import zmq
import vllm import vllm
import vllm.envs as envs import vllm.envs as envs
...@@ -21,17 +21,13 @@ from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG, ...@@ -21,17 +21,13 @@ from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
from vllm.executor.multiproc_worker_utils import _add_prefix from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx from vllm.utils import FlexibleArgumentParser, get_tcp_uri
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import (APIServerProcessManager, CoreEngine, from vllm.v1.utils import (APIServerProcessManager,
CoreEngineActorManager, EngineZmqAddresses, wait_for_completion_or_failure)
get_engine_client_zmq_addr,
wait_for_completion_or_failure,
wait_for_engine_startup)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -48,11 +44,15 @@ class ServeSubcommand(CLISubcommand): ...@@ -48,11 +44,15 @@ class ServeSubcommand(CLISubcommand):
if args.headless or args.api_server_count < 1: if args.headless or args.api_server_count < 1:
run_headless(args) run_headless(args)
elif args.api_server_count > 1:
run_multi_api_server(args)
else: else:
# Single API server (this process). if args.data_parallel_start_rank:
uvloop.run(run_server(args)) raise ValueError("data_parallel_start_rank is only "
"applicable in headless mode")
if args.api_server_count > 1:
run_multi_api_server(args)
else:
# Single API server (this process).
uvloop.run(run_server(args))
def validate(self, args: argparse.Namespace) -> None: def validate(self, args: argparse.Namespace) -> None:
validate_parsed_serve_args(args) validate_parsed_serve_args(args)
...@@ -121,14 +121,19 @@ def run_headless(args: argparse.Namespace): ...@@ -121,14 +121,19 @@ def run_headless(args: argparse.Namespace):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
handshake_address = get_tcp_uri(host, port)
if local_engine_count <= 0: if local_engine_count <= 0:
raise ValueError("data_parallel_size_local must be > 0 in " raise ValueError("data_parallel_size_local must be > 0 in "
"headless mode") "headless mode")
if parallel_config.data_parallel_rank is not None:
raise ValueError("data_parallel_rank is not applicable in "
"headless mode")
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
handshake_address = get_tcp_uri(host, port)
# Catch SIGTERM and SIGINT to allow graceful shutdown. # Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame): def signal_handler(signum, frame):
logger.debug("Received %d signal.", signum) logger.debug("Received %d signal.", signum)
...@@ -148,7 +153,7 @@ def run_headless(args: argparse.Namespace): ...@@ -148,7 +153,7 @@ def run_headless(args: argparse.Namespace):
start_index=args.data_parallel_start_rank, start_index=args.data_parallel_start_rank,
local_start_index=0, local_start_index=0,
vllm_config=vllm_config, vllm_config=vllm_config,
on_head_node=False, local_client=False,
handshake_address=handshake_address, handshake_address=handshake_address,
executor_class=Executor.get_class(vllm_config), executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
...@@ -192,117 +197,53 @@ def run_multi_api_server(args: argparse.Namespace): ...@@ -192,117 +197,53 @@ def run_multi_api_server(args: argparse.Namespace):
" api_server_count > 1") " api_server_count > 1")
model_config.disable_mm_preprocessor_cache = True model_config.disable_mm_preprocessor_cache = True
executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
dp_rank = parallel_config.data_parallel_rank
external_dp_lb = parallel_config.data_parallel_external_lb
assert external_dp_lb or dp_rank == 0
assert parallel_config.data_parallel_rank == 0 api_server_manager: Optional[APIServerProcessManager] = None
dp_size = parallel_config.data_parallel_size with launch_core_engines(vllm_config, executor_class, log_stats,
local_engine_count = parallel_config.data_parallel_size_local num_api_servers) as (local_engine_manager,
host = parallel_config.data_parallel_master_ip coordinator, addresses):
local_only = local_engine_count == dp_size
# Set up input and output addresses.
input_addresses = [
get_engine_client_zmq_addr(local_only, host)
for _ in range(num_api_servers)
]
output_addresses = [
get_engine_client_zmq_addr(local_only, host)
for _ in range(num_api_servers)
]
addresses = EngineZmqAddresses(
inputs=input_addresses,
outputs=output_addresses,
)
# Set up coordinator for dp > 1. # Construct common args for the APIServerProcessManager up-front.
coordinator = None api_server_manager_kwargs = dict(
stats_update_address = None
if dp_size > 1:
coordinator = DPCoordinator(parallel_config)
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
stats_update_address = coordinator.get_stats_publish_address()
logger.info("Started DP Coordinator process (PID: %d)",
coordinator.proc.pid)
if parallel_config.data_parallel_backend == "ray":
logger.info("Starting ray-based data parallel backend")
engine_actor_manager = CoreEngineActorManager(
vllm_config=vllm_config,
addresses=addresses,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)
# Start API servers using the manager
api_server_manager = APIServerProcessManager(
target_server_fn=run_api_server_worker_proc, target_server_fn=run_api_server_worker_proc,
listen_address=listen_address, listen_address=listen_address,
sock=sock, sock=sock,
args=args, args=args,
num_servers=num_api_servers, num_servers=num_api_servers,
input_addresses=input_addresses, input_addresses=addresses.inputs,
output_addresses=output_addresses, output_addresses=addresses.outputs,
stats_update_address=stats_update_address) stats_update_address=coordinator.get_stats_publish_address()
if coordinator else None)
wait_for_completion_or_failure(api_server_manager=api_server_manager,
engine_manager=engine_actor_manager, # For dp ranks > 0 in external DP LB mode, we must delay the
coordinator=coordinator) # start of the API servers until the local engine is started
return # (after the launcher context manager exits),
# since we get the front-end stats update address from the coordinator
handshake_address = get_engine_client_zmq_addr( # via the handshake with the local engine.
local_only, host, parallel_config.data_parallel_rpc_port) if dp_rank == 0 or not external_dp_lb:
# Start API servers using the manager.
with zmq_socket_ctx(handshake_address, zmq.ROUTER, api_server_manager = APIServerProcessManager(
bind=True) as handshake_socket: **api_server_manager_kwargs)
# Start local engines. # Start API servers now if they weren't already started.
if not local_engine_count: if api_server_manager is None:
local_engine_manager = None api_server_manager_kwargs["stats_update_address"] = (
else: addresses.frontend_stats_publish_address)
local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
handshake_address=handshake_address,
on_head_node=True,
local_engine_count=local_engine_count,
start_index=0,
local_start_index=0)
# Start API servers using the manager
api_server_manager = APIServerProcessManager( api_server_manager = APIServerProcessManager(
target_server_fn=run_api_server_worker_proc, **api_server_manager_kwargs)
listen_address=listen_address,
sock=sock, # Wait for API servers
args=args, wait_for_completion_or_failure(api_server_manager=api_server_manager,
num_servers=num_api_servers, engine_manager=local_engine_manager,
input_addresses=input_addresses, coordinator=coordinator)
output_addresses=output_addresses,
stats_update_address=stats_update_address)
# Wait for engine handshakes to complete.
core_engines = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)
]
wait_for_engine_startup(
handshake_socket,
addresses,
core_engines,
parallel_config,
vllm_config.cache_config,
local_engine_manager,
coordinator.proc if coordinator else None,
)
# Wait for API servers
wait_for_completion_or_failure(api_server_manager=api_server_manager,
engine_manager=local_engine_manager,
coordinator=coordinator)
def run_api_server_worker_proc(listen_address, def run_api_server_worker_proc(listen_address,
......
...@@ -10,7 +10,7 @@ import zmq ...@@ -10,7 +10,7 @@ import zmq
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket from vllm.utils import get_mp_context, make_zmq_socket
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
...@@ -48,20 +48,33 @@ class DPCoordinator: ...@@ -48,20 +48,33 @@ class DPCoordinator:
Engines will move into running state when receiving a new request or Engines will move into running state when receiving a new request or
START_DP_WAVE message. START_DP_WAVE message.
Note that when deployed in External LB mode, no stats will be published by
the engines and thus updates will only be sent to front-ends when the
request wave / running state changes.
""" """
def __init__(self, parallel_config: ParallelConfig): def __init__(self, parallel_config: ParallelConfig):
# Assume coordinator is colocated with front-end procs.
front_publish_address = get_open_zmq_ipc_path()
dp_size = parallel_config.data_parallel_size dp_size = parallel_config.data_parallel_size
assert dp_size > 1, "Coordinator only used for data parallel" assert dp_size > 1, "Coordinator only used for data parallel"
local_only = dp_size == parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip host = parallel_config.data_parallel_master_ip
back_publish_address = get_engine_client_zmq_addr(local_only, host) external_lb = parallel_config.data_parallel_external_lb
back_output_address = get_engine_client_zmq_addr(local_only, host)
# Assume coordinator is colocated with front-end procs when not in
# external DP LB mode.
front_publish_address = get_engine_client_zmq_addr(
local_only=not external_lb, host=host)
local_only_eng = dp_size == parallel_config.data_parallel_size_local
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
# When in external LB mode, load stats aren't published, only changes
# to request wave / running state, so we don't need to rate-limit the
# updates to the front-end proc(s).
min_stats_update_interval_ms = 0 if external_lb else 100
context = get_mp_context() context = get_mp_context()
self.proc: multiprocessing.Process = context.Process( self.proc: multiprocessing.Process = context.Process(
...@@ -72,6 +85,7 @@ class DPCoordinator: ...@@ -72,6 +85,7 @@ class DPCoordinator:
"front_publish_address": front_publish_address, "front_publish_address": front_publish_address,
"back_output_address": back_output_address, "back_output_address": back_output_address,
"back_publish_address": back_publish_address, "back_publish_address": back_publish_address,
"min_stats_update_interval_ms": min_stats_update_interval_ms,
}, },
daemon=True) daemon=True)
self.proc.start() self.proc.start()
...@@ -100,12 +114,16 @@ class EngineState: ...@@ -100,12 +114,16 @@ class EngineState:
class CoordinatorProc: class CoordinatorProc:
def __init__(self, engine_count: int): def __init__(self,
engine_count: int,
min_stats_update_interval_ms: int = 100):
self.ctx = zmq.Context() self.ctx = zmq.Context()
self.engines = [EngineState() for _ in range(engine_count)] self.engines = [EngineState() for _ in range(engine_count)]
self.stats_update_interval_ms = min_stats_update_interval_ms
self.current_wave = 0 self.current_wave = 0
self.engines_running = False self.engines_running = False
self.stats_changed = False self.stats_changed = False
...@@ -116,8 +134,11 @@ class CoordinatorProc: ...@@ -116,8 +134,11 @@ class CoordinatorProc:
front_publish_address: str, front_publish_address: str,
back_output_address: str, back_output_address: str,
back_publish_address: str, back_publish_address: str,
min_stats_update_interval_ms: int = 100,
): ):
coordinator = CoordinatorProc(engine_count=engine_count) coordinator = CoordinatorProc(
engine_count=engine_count,
min_stats_update_interval_ms=min_stats_update_interval_ms)
try: try:
coordinator.process_input_socket( coordinator.process_input_socket(
front_publish_address, front_publish_address,
...@@ -156,9 +177,10 @@ class CoordinatorProc: ...@@ -156,9 +177,10 @@ class CoordinatorProc:
last_publish_time = 0 last_publish_time = 0
while True: while True:
elapsed = int(time.time() * 1000) - last_publish_time elapsed = int(time.time() * 1000) - last_publish_time
# Send at 100 ms interval if the stats have changed, # Send at stats_update_interval_ms interval if the stats have
# or otherwise every 3 seconds. # changed, or otherwise every 4 seconds.
wait_for = 100 if self.stats_changed else 3000 wait_for = (self.stats_update_interval_ms
if self.stats_changed else 4000)
events = poller.poll(timeout=max(0, wait_for - elapsed)) events = poller.poll(timeout=max(0, wait_for - elapsed))
if not events: if not events:
# Poller timeout - publish current stats to front-ends. # Poller timeout - publish current stats to front-ends.
...@@ -174,7 +196,7 @@ class CoordinatorProc: ...@@ -174,7 +196,7 @@ class CoordinatorProc:
if publish_front in events: if publish_front in events:
buffer = publish_front.recv() buffer = publish_front.recv()
if buffer == b'\x01': if buffer in (b'\x01', b'\x00'):
# Ignore subscription messages. # Ignore subscription messages.
continue continue
......
...@@ -34,6 +34,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler ...@@ -34,6 +34,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput) EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
...@@ -41,7 +42,6 @@ from vllm.v1.outputs import ModelRunnerOutput ...@@ -41,7 +42,6 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -367,10 +367,11 @@ class EngineCoreProc(EngineCore): ...@@ -367,10 +367,11 @@ class EngineCoreProc(EngineCore):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
on_head_node: bool, local_client: bool,
handshake_address: str, handshake_address: str,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
client_handshake_address: Optional[str] = None,
engine_index: int = 0, engine_index: int = 0,
): ):
self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
...@@ -383,12 +384,21 @@ class EngineCoreProc(EngineCore): ...@@ -383,12 +384,21 @@ class EngineCoreProc(EngineCore):
identity = self.engine_index.to_bytes(length=2, byteorder="little") identity = self.engine_index.to_bytes(length=2, byteorder="little")
self.engines_running = False self.engines_running = False
with self._perform_handshake(handshake_address, identity, on_head_node, with self._perform_handshakes(handshake_address, identity,
vllm_config) as addresses: local_client, vllm_config,
client_handshake_address) as addresses:
self.client_count = len(addresses.outputs) self.client_count = len(addresses.outputs)
# Set up data parallel environment. # Set up data parallel environment.
self.has_coordinator = addresses.coordinator_output is not None self.has_coordinator = addresses.coordinator_output is not None
self.frontend_stats_publish_address = (
addresses.frontend_stats_publish_address)
# Only publish request queue stats to coordinator for "internal"
# LB mode.
self.publish_dp_lb_stats = (
self.has_coordinator
and not vllm_config.parallel_config.data_parallel_external_lb)
self._init_data_parallel(vllm_config) self._init_data_parallel(vllm_config)
super().__init__(vllm_config, executor_class, log_stats, super().__init__(vllm_config, executor_class, log_stats,
...@@ -414,45 +424,102 @@ class EngineCoreProc(EngineCore): ...@@ -414,45 +424,102 @@ class EngineCoreProc(EngineCore):
self.output_thread.start() self.output_thread.start()
@contextmanager @contextmanager
def _perform_handshake( def _perform_handshakes(
self, handshake_address: str, identity: bytes, on_head_node: bool, self,
vllm_config: VllmConfig handshake_address: str,
identity: bytes,
local_client: bool,
vllm_config: VllmConfig,
client_handshake_address: Optional[str],
) -> Generator[EngineZmqAddresses, None, None]: ) -> Generator[EngineZmqAddresses, None, None]:
"""
Perform startup handshakes.
For DP=1 or offline mode, this is with the colocated front-end process.
For DP>1 with internal loadbalancing this is with the shared front-end
process which may reside on a different node.
For DP>1 with external loadbalancing, two handshakes are performed:
- With the rank 0 front-end process which retrieves the
DP Coordinator ZMQ addresses and DP process group address.
- With the colocated front-end process which retrieves the
client input/output socket addresses.
with the exception of the rank 0 engine itself which doesn't require
the second handshake.
Here, "front-end" process can mean the process containing the engine
core client (which is the API server process in the case the API
server is not scaled out), OR the launcher process running the
run_multi_api_server() function in serve.py.
"""
input_ctx = zmq.Context() input_ctx = zmq.Context()
with make_zmq_socket(input_ctx, is_local = local_client and client_handshake_address is None
handshake = self._perform_handshake(input_ctx, handshake_address,
identity, is_local, vllm_config,
vllm_config.parallel_config)
if client_handshake_address is None:
with handshake as addresses:
yield addresses
else:
local_handshake = self._perform_handshake(
input_ctx, client_handshake_address, identity, local_client,
vllm_config)
with handshake as addresses, local_handshake as client_addresses:
addresses.inputs = client_addresses.inputs
addresses.outputs = client_addresses.outputs
yield addresses
# Update config which may have changed from the handshake
vllm_config.__post_init__()
@contextmanager
def _perform_handshake(
self,
ctx: zmq.Context,
handshake_address: str,
identity: bytes,
local_client: bool,
vllm_config: VllmConfig,
parallel_config_to_update: Optional[ParallelConfig] = None,
) -> Generator[EngineZmqAddresses, None, None]:
with make_zmq_socket(ctx,
handshake_address, handshake_address,
zmq.DEALER, zmq.DEALER,
identity=identity, identity=identity,
linger=5000, linger=5000,
bind=False) as handshake_socket: bind=False) as handshake_socket:
# Register engine with front-end. # Register engine with front-end.
addresses = self.startup_handshake(handshake_socket, on_head_node, addresses = self.startup_handshake(handshake_socket, local_client,
vllm_config.parallel_config) parallel_config_to_update)
# Update config which may have changed from the handshake
vllm_config.__post_init__()
yield addresses yield addresses
# Send ready message. # Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
# We pass back the coordinator stats update address here for the
# external LB case for our colocated front-end to use (coordinator
# only runs with rank 0).
dp_stats_address = self.frontend_stats_publish_address
handshake_socket.send( handshake_socket.send(
msgspec.msgpack.encode({ msgspec.msgpack.encode({
"status": "READY", "status": "READY",
"local": on_head_node, "local": local_client,
"num_gpu_blocks": num_gpu_blocks, "num_gpu_blocks": num_gpu_blocks,
"dp_stats_address": dp_stats_address,
})) }))
@staticmethod @staticmethod
def startup_handshake( def startup_handshake(
handshake_socket: zmq.Socket, on_head_node: bool, handshake_socket: zmq.Socket,
parallel_config: ParallelConfig) -> EngineZmqAddresses: local_client: bool,
parallel_config: Optional[ParallelConfig] = None,
) -> EngineZmqAddresses:
# Send registration message. # Send registration message.
handshake_socket.send( handshake_socket.send(
msgspec.msgpack.encode({ msgspec.msgpack.encode({
"status": "HELLO", "status": "HELLO",
"local": on_head_node, "local": local_client,
})) }))
# Receive initialization message. # Receive initialization message.
...@@ -466,9 +533,9 @@ class EngineCoreProc(EngineCore): ...@@ -466,9 +533,9 @@ class EngineCoreProc(EngineCore):
init_bytes, type=EngineHandshakeMetadata) init_bytes, type=EngineHandshakeMetadata)
logger.debug("Received init message: %s", init_message) logger.debug("Received init message: %s", init_message)
received_parallel_config = init_message.parallel_config if parallel_config is not None:
for key, value in received_parallel_config.items(): for key, value in init_message.parallel_config.items():
setattr(parallel_config, key, value) setattr(parallel_config, key, value)
return init_message.addresses return init_message.addresses
...@@ -749,12 +816,12 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -749,12 +816,12 @@ class DPEngineCoreProc(EngineCoreProc):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
on_head_node: bool, local_client: bool,
handshake_address: str, handshake_address: str,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
client_handshake_address: Optional[str] = None,
): ):
self._decorate_logs() self._decorate_logs()
# Counts forward-passes of the model so that we can synchronize # Counts forward-passes of the model so that we can synchronize
...@@ -765,8 +832,9 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -765,8 +832,9 @@ class DPEngineCoreProc(EngineCoreProc):
# Initialize the engine. # Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(vllm_config, on_head_node, handshake_address, super().__init__(vllm_config, local_client, handshake_address,
executor_class, log_stats, dp_rank) executor_class, log_stats, client_handshake_address,
dp_rank)
def _decorate_logs(self): def _decorate_logs(self):
# Add process-specific prefix to stdout and stderr before # Add process-specific prefix to stdout and stderr before
...@@ -799,10 +867,18 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -799,10 +867,18 @@ class DPEngineCoreProc(EngineCoreProc):
from vllm.platforms import current_platform from vllm.platforms import current_platform
device_control_env_var = current_platform.device_control_env_var device_control_env_var = current_platform.device_control_env_var
world_size = vllm_config.parallel_config.world_size world_size = vllm_config.parallel_config.world_size
os.environ[device_control_env_var] = ",".join( # Set CUDA_VISIBLE_DEVICES or equivalent.
str(current_platform.device_id_to_physical_device_id(i)) try:
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * os.environ[device_control_env_var] = ",".join(
world_size)) str(current_platform.device_id_to_physical_device_id(i))
for i in range(local_dp_rank *
world_size, (local_dp_rank + 1) * world_size))
except IndexError as e:
raise Exception(
f"Error setting {device_control_env_var}: "
f"local range: [{local_dp_rank * world_size}, "
f"{(local_dp_rank + 1) * world_size}) "
f"base value: \"{os.getenv(device_control_env_var)}\"") from e
self.dp_rank = dp_rank self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
...@@ -839,7 +915,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -839,7 +915,7 @@ class DPEngineCoreProc(EngineCoreProc):
super()._handle_client_request(request_type, request) super()._handle_client_request(request_type, request)
def _maybe_publish_request_counts(self): def _maybe_publish_request_counts(self):
if not self.has_coordinator: if not self.publish_dp_lb_stats:
return return
# Publish our request counts (if they've changed). # Publish our request counts (if they've changed).
...@@ -892,9 +968,9 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -892,9 +968,9 @@ class DPEngineCoreProc(EngineCoreProc):
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
# Optimization - only perform finish-sync all-reduce every 24 steps. # Optimization - only perform finish-sync all-reduce every 32 steps.
self.counter += 1 self.counter += 1
if self.counter != 24: if self.counter != 32:
return True return True
self.counter = 0 self.counter = 0
...@@ -910,7 +986,7 @@ class DPEngineCoreActor(DPEngineCoreProc): ...@@ -910,7 +986,7 @@ class DPEngineCoreActor(DPEngineCoreProc):
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
on_head_node: bool, local_client: bool,
addresses: EngineZmqAddresses, addresses: EngineZmqAddresses,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
...@@ -927,15 +1003,16 @@ class DPEngineCoreActor(DPEngineCoreProc): ...@@ -927,15 +1003,16 @@ class DPEngineCoreActor(DPEngineCoreProc):
# data parallel groups. # data parallel groups.
del os.environ['CUDA_VISIBLE_DEVICES'] del os.environ['CUDA_VISIBLE_DEVICES']
super().__init__(vllm_config, on_head_node, "", executor_class, super().__init__(vllm_config, local_client, "", executor_class,
log_stats) log_stats)
def _decorate_logs(self): def _decorate_logs(self):
pass pass
@contextmanager @contextmanager
def _perform_handshake(self, handshake_address: str, identity: bytes, def _perform_handshakes(self, handshake_address: str, identity: bytes,
on_head_node: bool, vllm_config: VllmConfig): local_client: bool, vllm_config: VllmConfig,
client_handshake_address: Optional[str]):
""" """
For Ray, we don't need to actually perform handshake. For Ray, we don't need to actually perform handshake.
All addresses information is known before the actor creation. All addresses information is known before the actor creation.
......
This diff is collapsed.
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse import argparse
import multiprocessing import multiprocessing
import time import time
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from multiprocessing import connection
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload) Union, overload)
import msgspec
import torch import torch
import zmq
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path, from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
get_tcp_uri, kill_process_tree) kill_process_tree)
from vllm.v1.executor.abstract import Executor
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager)
logger = init_logger(__name__) logger = init_logger(__name__)
T = TypeVar("T") T = TypeVar("T")
STARTUP_POLL_PERIOD_MS = 10000
class ConstantList(Generic[T], Sequence): class ConstantList(Generic[T], Sequence):
...@@ -111,47 +102,16 @@ class ConstantList(Generic[T], Sequence): ...@@ -111,47 +102,16 @@ class ConstantList(Generic[T], Sequence):
def get_engine_client_zmq_addr(local_only: bool, def get_engine_client_zmq_addr(local_only: bool,
host: str, host: str,
port: int = 0) -> str: port: int = 0) -> str:
return get_open_zmq_ipc_path() if local_only else (get_tcp_uri( """Assign a new ZMQ socket address.
host, port or get_open_port()))
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank."""
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(2, "little")
self.state = CoreEngineState.NEW
If local_only is True, participants are colocated and so a unique IPC
address will be returned.
@dataclass Otherwise, the provided host and port will be used to construct a TCP
class EngineZmqAddresses: address (port == 0 means assign an available port)."""
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None
return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
@dataclass host, port or get_open_port()))
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]
class APIServerProcessManager: class APIServerProcessManager:
...@@ -219,339 +179,10 @@ class APIServerProcessManager: ...@@ -219,339 +179,10 @@ class APIServerProcessManager:
self._finalizer() self._finalizer()
class CoreEngineProcManager:
    """
    Owns the background engine core processes used by the AsyncLLM and
    LLMEngine: builds them, starts them, reports which have exited, and
    shuts them all down.
    """

    def __init__(
        self,
        target_fn: Callable,
        local_engine_count: int,
        start_index: int,
        local_start_index: int,
        vllm_config: VllmConfig,
        on_head_node: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Create and start ``local_engine_count`` engine core processes.

        Process ``i`` runs ``target_fn`` with the shared keyword arguments
        plus ``dp_rank=start_index + i`` and
        ``local_dp_rank=local_start_index + i``.
        """
        mp_context = get_mp_context()
        # Keyword arguments that are identical for every engine process.
        shared_kwargs = {
            "vllm_config": vllm_config,
            "on_head_node": on_head_node,
            "handshake_address": handshake_address,
            "executor_class": executor_class,
            "log_stats": log_stats,
        }

        self.processes: list[BaseProcess] = []
        for offset in range(local_engine_count):
            global_rank = start_index + offset
            per_rank_kwargs = {
                "dp_rank": global_rank,
                "local_dp_rank": local_start_index + offset,
            }
            # Build (but don't yet start) the EngineCore background process.
            engine_proc = mp_context.Process(
                target=target_fn,
                name=f"EngineCore_{global_rank}",
                kwargs=shared_kwargs | per_rank_kwargs)
            self.processes.append(engine_proc)

        # Registered before anything is started so that the processes are
        # passed to shutdown() once this manager is garbage collected.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)
        try:
            for engine_proc in self.processes:
                engine_proc.start()
        finally:
            # Kill other procs if not all are running.
            if self.finished_procs():
                self.close()

    def close(self):
        """Shutdown all procs."""
        self._finalizer()

    def join_first(self):
        """Wait for any process to exit."""
        connection.wait(engine_proc.sentinel for engine_proc in self.processes)

    def sentinels(self) -> list:
        """Return the sentinel handle of every managed process."""
        handles = []
        for engine_proc in self.processes:
            handles.append(engine_proc.sentinel)
        return handles

    def finished_procs(self) -> dict[str, int]:
        """Returns dict of proc name -> exit code for any finished procs."""
        exited: dict[str, int] = {}
        for engine_proc in self.processes:
            code = engine_proc.exitcode
            # exitcode stays None until the process has terminated.
            if code is not None:
                exited[engine_proc.name] = code
        return exited
class CoreEngineActorManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of core engine Ray actors used by the AsyncLLM and LLMEngine.

    Different from CoreEngineProcManager, this class manages
    core engines for both local and remote nodes.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        placement_groups: Optional[list["PlacementGroup"]] = None,
        local_dp_ranks: Optional[list[int]] = None,
    ):
        """Create one DPEngineCoreActor per data parallel rank and start it.

        Args:
            vllm_config: Base config; deep-copied per rank so each actor
                gets its own placement group set on its parallel config.
            addresses: ZMQ addresses passed through to every actor.
            executor_class: Executor type passed through to every actor.
            log_stats: Passed through to every actor.
            placement_groups: Optional pre-built placement groups, one per
                DP rank. When omitted they are created here (and removed
                again by close()).
            local_dp_ranks: Local rank of each DP rank; required whenever
                placement_groups is given, and must have the same length.
        """
        import copy

        import ray
        from ray.util.scheduling_strategies import (
            PlacementGroupSchedulingStrategy)

        from vllm.v1.engine.core import DPEngineCoreActor

        self.local_engine_actors: list[ray.ActorHandle] = []
        self.remote_engine_actors: list[ray.ActorHandle] = []
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local
        world_size = vllm_config.parallel_config.world_size

        if ray.is_initialized():
            logger.info(
                "Ray is already initialized. Skipping Ray initialization.")
        else:
            ray.init()

        if placement_groups is not None:
            assert local_dp_ranks is not None, (
                "local_dp_ranks must be provided if "
                "placement_groups is provided")
            assert len(placement_groups) == len(local_dp_ranks), (
                "placement_groups and local_dp_ranks must "
                "have the same length")
            logger.info("Using provided placement groups")
            # TODO(rui): validate passed-in placement groups
            # Caller-owned groups are not removed by close().
            self.created_placement_groups = []
        else:
            placement_groups, local_dp_ranks = \
                CoreEngineActorManager.create_dp_placement_groups(vllm_config)
            self.created_placement_groups = placement_groups
        assert len(placement_groups) == dp_size, (
            "Number of placement groups must match data parallel size")

        refs = []
        for index in range(dp_size):
            local_index = local_dp_ranks[index]
            dp_vllm_config = copy.deepcopy(vllm_config)
            pg = placement_groups[index]
            dp_vllm_config.parallel_config.placement_group = pg
            # The first local_engine_count ranks are treated as head-node
            # (local) engines.
            on_head_node = index < local_engine_count
            # Bundle index world_size is the trailing CPU bundle in groups
            # built by create_dp_placement_groups (world_size GPU bundles
            # followed by one CPU bundle).
            actor = ray.remote(DPEngineCoreActor).options(
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=pg,
                    placement_group_bundle_index=world_size,
                )).remote(vllm_config=dp_vllm_config,
                          executor_class=executor_class,
                          log_stats=log_stats,
                          on_head_node=on_head_node,
                          addresses=addresses,
                          dp_rank=index,
                          local_dp_rank=local_index)
            if on_head_node:
                self.local_engine_actors.append(actor)
            else:
                self.remote_engine_actors.append(actor)
            refs.append(actor.wait_for_init.remote())

        # Block until every actor has finished initializing.
        ray.get(refs)
        self.run_refs = []
        for actor in self.local_engine_actors + self.remote_engine_actors:
            self.run_refs.append(actor.run.remote())

    @staticmethod
    def create_dp_placement_groups(
            vllm_config: VllmConfig
    ) -> tuple[list["PlacementGroup"], list[int]]:
        """Create one STRICT_PACK placement group per data parallel rank.

        The DP master node receives data_parallel_size_local groups; the
        remaining nodes are filled, in order, with as many groups as their
        available GPUs allow until data_parallel_size groups exist.

        Returns:
            (placement_groups, local_dp_ranks), where local_dp_ranks[i] is
            the per-node index of placement_groups[i].
        """
        import ray
        from ray._private.state import available_resources_per_node
        from ray.util.state import list_nodes

        logger.info("Creating placement groups for data parallel")
        dp_master_ip = \
            vllm_config.parallel_config.data_parallel_master_ip
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local

        # Query the cluster once; the key is False for the master node, so
        # it sorts to the front.
        nodes = sorted(list_nodes(),
                       key=lambda node: node.node_ip != dp_master_ip)
        assert nodes[0].node_ip == dp_master_ip, (
            "The first node must be the head node")
        assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
            "There can only be one head node")

        available_resources = available_resources_per_node()
        world_size = vllm_config.parallel_config.world_size
        placement_groups: list[PlacementGroup] = []
        local_dp_ranks: list[int] = []

        for node in nodes:
            node_ip = node.node_ip
            node_resources = available_resources[node.node_id]
            # For now, each DP rank can only be assigned to one node
            # TODO(rui): support allocating a single DP rank
            # to multiple nodes
            # A node reporting no "GPU" entry can host zero engines; treat
            # it as such instead of failing with a bare KeyError.
            available_engine_count = (int(node_resources.get("GPU", 0)) //
                                      world_size)
            if node_ip == dp_master_ip:
                assert available_engine_count >= local_engine_count, (
                    "Not enough resources to allocate DP ranks "
                    f"on DP master node {node_ip}")
                for i in range(local_engine_count):
                    # The tiny "node:<ip>" resource request ties these
                    # bundles to the master node.
                    bundles = [{
                        "GPU": 1.0,
                        "node:" + dp_master_ip: 0.001
                    }] * world_size + [{
                        "CPU": 1.0
                    }]
                    pg = ray.util.placement_group(
                        name=f"dp_rank_{len(placement_groups)}",
                        strategy="STRICT_PACK",
                        bundles=bundles,
                    )
                    placement_groups.append(pg)
                    local_dp_ranks.append(i)
            else:
                for i in range(available_engine_count):
                    # Stop filling this node once every DP rank is placed.
                    if len(placement_groups) == dp_size:
                        break
                    bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
                    pg = ray.util.placement_group(
                        name=f"dp_rank_{len(placement_groups)}",
                        strategy="STRICT_PACK",
                        bundles=bundles,
                    )
                    placement_groups.append(pg)
                    local_dp_ranks.append(i)
        return placement_groups, local_dp_ranks

    def get_run_refs(self):
        """Return the refs of each actor's run() call."""
        return self.run_refs

    def close(self):
        """Kill every engine actor and remove placement groups made here."""
        import ray
        for actor in self.local_engine_actors + self.remote_engine_actors:
            ray.kill(actor)
        for pg in self.created_placement_groups:
            ray.util.remove_placement_group(pg)
def wait_for_engine_startup(
    handshake_socket: zmq.Socket,
    addresses: EngineZmqAddresses,
    core_engines: list[CoreEngine],
    parallel_config: ParallelConfig,
    cache_config: CacheConfig,
    proc_manager: Optional[CoreEngineProcManager],
    coord_process: Optional[Process],
):
    """Block until every core engine has sent HELLO and then READY.

    Each HELLO is answered with an EngineHandshakeMetadata message carrying
    ``addresses`` and the DP master ip/port/size. Each READY adds the
    engine's reported ``num_gpu_blocks`` to ``cache_config.num_gpu_blocks``.

    Raises:
        RuntimeError: if a watched process sentinel fires, a message comes
            from an unknown engine identity, the local/remote flag doesn't
            match the expected one, or a message arrives out of order.
    """
    # Wait for engine core process(es) to send ready messages.
    num_local = parallel_config.data_parallel_size_local
    num_remote = len(core_engines) - num_local
    # Both counters are indexed [local, remote].
    awaiting_hello = [num_local, num_remote]
    awaiting_ready = [0, 0]

    poller = zmq.Poller()
    poller.register(handshake_socket, zmq.POLLIN)
    # Also watch process sentinels so a crashed process is noticed promptly.
    if proc_manager is not None:
        for sentinel in proc_manager.sentinels():
            poller.register(sentinel, zmq.POLLIN)
    if coord_process is not None:
        poller.register(coord_process.sentinel, zmq.POLLIN)

    while any(awaiting_hello) or any(awaiting_ready):
        fired = poller.poll(STARTUP_POLL_PERIOD_MS)
        if not fired:
            # Poll timed out: report progress and keep waiting.
            if any(awaiting_hello):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to connect.", *awaiting_hello)
            if any(awaiting_ready):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to start.", *awaiting_ready)
            continue

        handshake_only = (len(fired) <= 1
                          and fired[0][0] == handshake_socket)
        if not handshake_only:
            # One of the local core processes exited.
            finished = proc_manager.finished_procs() if proc_manager else {}
            if coord_process is not None and coord_process.exitcode is not None:
                finished[coord_process.name] = coord_process.exitcode
            raise RuntimeError("Engine core initialization failed. "
                               "See root cause above. "
                               f"Failed core proc(s): {finished}")

        # Receive HELLO and READY messages from the input socket.
        eng_identity, payload = handshake_socket.recv_multipart()
        eng_index = int.from_bytes(eng_identity, "little")
        engine = next(
            (cand for cand in core_engines if cand.identity == eng_identity),
            None)
        if engine is None:
            raise RuntimeError(f"Message from engine with unexpected data "
                               f"parallel rank: {eng_index}")

        msg = msgspec.msgpack.decode(payload)
        status, local = msg["status"], msg["local"]
        origin = "local" if local else "remote"
        if local != engine.local:
            expected_origin = "local" if engine.local else "remote"
            raise RuntimeError(f"{status} message from "
                               f"{origin} "
                               f"engine {eng_index}, expected it to be "
                               f"{expected_origin}")

        # Index into the [local, remote] counters.
        slot = 0 if local else 1
        if status == "HELLO" and engine.state == CoreEngineState.NEW:
            # Send init message with DP config info.
            dp_settings = {
                "data_parallel_master_ip":
                parallel_config.data_parallel_master_ip,
                "data_parallel_master_port":
                parallel_config.data_parallel_master_port,
                "data_parallel_size":
                parallel_config.data_parallel_size,
            }
            init_message = msgspec.msgpack.encode(
                EngineHandshakeMetadata(addresses=addresses,
                                        parallel_config=dp_settings))
            handshake_socket.send_multipart((eng_identity, init_message),
                                            copy=False)
            awaiting_hello[slot] -= 1
            awaiting_ready[slot] += 1
            engine.state = CoreEngineState.CONNECTED
        elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
            # Setup KV cache config with initialization state from
            # engine core process. Sum values from all engines in DP case.
            blocks_so_far = cache_config.num_gpu_blocks or 0
            cache_config.num_gpu_blocks = (blocks_so_far +
                                           msg["num_gpu_blocks"])
            awaiting_ready[slot] -= 1
            engine.state = CoreEngineState.READY
        else:
            raise RuntimeError(f"Unexpected {status} message for "
                               f"{origin} engine "
                               f"{eng_index} in {engine.state} state.")

        logger.debug("%s from %s core engine process %s.", status, origin,
                     eng_index)
def wait_for_completion_or_failure( def wait_for_completion_or_failure(
api_server_manager: APIServerProcessManager, api_server_manager: APIServerProcessManager,
engine_manager: Optional[Union[CoreEngineProcManager, engine_manager: Optional[Union["CoreEngineProcManager",
CoreEngineActorManager]] = None, "CoreEngineActorManager"]] = None,
coordinator: Optional["DPCoordinator"] = None) -> None: coordinator: Optional["DPCoordinator"] = None) -> None:
"""Wait for all processes to complete or detect if any fail. """Wait for all processes to complete or detect if any fail.
...@@ -565,6 +196,9 @@ def wait_for_completion_or_failure( ...@@ -565,6 +196,9 @@ def wait_for_completion_or_failure(
coordinator: The coordinator for data parallel. coordinator: The coordinator for data parallel.
""" """
from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager)
try: try:
logger.info("Waiting for API servers to complete ...") logger.info("Waiting for API servers to complete ...")
# Create a mapping of sentinels to their corresponding processes # Create a mapping of sentinels to their corresponding processes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment