"vllm/vscode:/vscode.git/clone" did not exist on "45f526d65237d9073a5f3be166b306580687f210"
Unverified Commit 2dbe8c07 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Perf] API-server scaleout with many-to-many server-engine comms (#17546)

parent 84ec470f
...@@ -7,6 +7,7 @@ import threading ...@@ -7,6 +7,7 @@ import threading
import time import time
from collections import deque from collections import deque
from concurrent.futures import Future from concurrent.futures import Future
from contextlib import ExitStack
from inspect import isclass, signature from inspect import isclass, signature
from logging import DEBUG from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union from typing import Any, Callable, Optional, TypeVar, Union
...@@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception ...@@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value) maybe_register_config_serialize_by_value)
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs) unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.interface import SchedulerInterface
...@@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, ...@@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -211,7 +214,7 @@ class EngineCore: ...@@ -211,7 +214,7 @@ class EngineCore:
# Re-raise exception # Re-raise exception
raise err raise err
def step(self) -> tuple[EngineCoreOutputs, bool]: def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output. """Schedule, execute, and make output.
Returns tuple of outputs and a flag indicating whether the model Returns tuple of outputs and a flag indicating whether the model
...@@ -221,10 +224,7 @@ class EngineCore: ...@@ -221,10 +224,7 @@ class EngineCore:
# Check for any requests remaining in the scheduler - unfinished, # Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch. # or finished and not yet removed from the batch.
if not self.scheduler.has_requests(): if not self.scheduler.has_requests():
return EngineCoreOutputs( return {}, False
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
), False
scheduler_output = self.scheduler.schedule() scheduler_output = self.scheduler.schedule()
model_output = self.execute_model(scheduler_output) model_output = self.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output( engine_core_outputs = self.scheduler.update_from_output(
...@@ -234,7 +234,7 @@ class EngineCore: ...@@ -234,7 +234,7 @@ class EngineCore:
scheduler_output.total_num_scheduled_tokens > 0) scheduler_output.total_num_scheduled_tokens > 0)
def step_with_batch_queue( def step_with_batch_queue(
self) -> tuple[Optional[EngineCoreOutputs], bool]: self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
"""Schedule and execute batches with the batch queue. """Schedule and execute batches with the batch queue.
Note that if nothing to output in this step, None is returned. Note that if nothing to output in this step, None is returned.
...@@ -276,8 +276,8 @@ class EngineCore: ...@@ -276,8 +276,8 @@ class EngineCore:
# Blocking until the first result is available. # Blocking until the first result is available.
model_output = future.result() model_output = future.result()
self.batch_queue.task_done() self.batch_queue.task_done()
engine_core_outputs = self.scheduler.update_from_output( engine_core_outputs = (self.scheduler.update_from_output(
scheduler_output, model_output) scheduler_output, model_output))
return engine_core_outputs, scheduled_batch return engine_core_outputs, scheduled_batch
...@@ -362,7 +362,7 @@ class EngineCoreProc(EngineCore): ...@@ -362,7 +362,7 @@ class EngineCoreProc(EngineCore):
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
on_head_node: bool, on_head_node: bool,
input_address: str, handshake_address: str,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
engine_index: int = 0, engine_index: int = 0,
...@@ -375,65 +375,70 @@ class EngineCoreProc(EngineCore): ...@@ -375,65 +375,70 @@ class EngineCoreProc(EngineCore):
# Create input socket. # Create input socket.
input_ctx = zmq.Context() input_ctx = zmq.Context()
identity = engine_index.to_bytes(length=2, byteorder="little") identity = engine_index.to_bytes(length=2, byteorder="little")
input_socket = make_zmq_socket(input_ctx, with make_zmq_socket(input_ctx,
input_address, handshake_address,
zmq.DEALER, zmq.DEALER,
identity=identity, identity=identity,
bind=False) linger=5000,
try: bind=False) as handshake_socket:
# Register engine with front-end. # Register engine with front-end.
output_address = self.startup_handshake( addresses = self.startup_handshake(handshake_socket, on_head_node,
input_socket, on_head_node, vllm_config.parallel_config) vllm_config.parallel_config)
self.client_count = len(addresses.outputs)
# Update config which may have changed from the handshake. # Update config which may have changed from the handshake.
vllm_config.__post_init__() vllm_config.__post_init__()
# Set up data parallel environment. # Set up data parallel environment.
self.has_coordinator = addresses.coordinator_output is not None
self._init_data_parallel(vllm_config) self._init_data_parallel(vllm_config)
# Initialize engine core and model. # Initialize engine core and model.
super().__init__(vllm_config, executor_class, log_stats, super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback) executor_fail_callback)
self.engine_index = engine_index
self.step_fn = (self.step if self.batch_queue is None else self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue) self.step_with_batch_queue)
self.engines_running = False self.engines_running = False
self.last_counts = (0, 0)
# Send ready message. # Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
input_socket.send( handshake_socket.send(
msgspec.msgpack.encode({ msgspec.msgpack.encode({
"status": "READY", "status": "READY",
"local": on_head_node, "local": on_head_node,
"num_gpu_blocks": num_gpu_blocks, "num_gpu_blocks": num_gpu_blocks,
})) }))
# Background Threads and Queues for IO. These enable us to # Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL, # overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the # and to overlap some serialization/deserialization with the
# model forward pass. # model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue. # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = input_queue self.input_queue = input_queue
self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
threading.Thread(target=self.process_input_socket, bytes]]()
args=(input_socket, ), threading.Thread(target=self.process_input_sockets,
daemon=True).start() args=(addresses.inputs, addresses.coordinator_input,
input_socket = None identity),
self.output_thread = threading.Thread( daemon=True).start()
target=self.process_output_socket, self.output_thread = threading.Thread(
args=(output_address, engine_index), target=self.process_output_sockets,
daemon=True) args=(addresses.outputs, addresses.coordinator_output,
self.output_thread.start() engine_index),
finally: daemon=True)
if input_socket is not None: self.output_thread.start()
input_socket.close(linger=0)
@staticmethod @staticmethod
def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, def startup_handshake(
parallel_config: ParallelConfig) -> str: handshake_socket: zmq.Socket, on_head_node: bool,
parallel_config: ParallelConfig) -> EngineZmqAddresses:
# Send registration message. # Send registration message.
input_socket.send( handshake_socket.send(
msgspec.msgpack.encode({ msgspec.msgpack.encode({
"status": "HELLO", "status": "HELLO",
"local": on_head_node, "local": on_head_node,
...@@ -441,22 +446,20 @@ class EngineCoreProc(EngineCore): ...@@ -441,22 +446,20 @@ class EngineCoreProc(EngineCore):
# Receive initialization message. # Receive initialization message.
logger.info("Waiting for init message from front-end.") logger.info("Waiting for init message from front-end.")
if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000): if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
raise RuntimeError("Did not receive response from front-end " raise RuntimeError("Did not receive response from front-end "
f"process within {HANDSHAKE_TIMEOUT_MINS} " f"process within {HANDSHAKE_TIMEOUT_MINS} "
f"minutes") f"minutes")
init_bytes = input_socket.recv() init_bytes = handshake_socket.recv()
init_message = msgspec.msgpack.decode(init_bytes) init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
init_bytes, type=EngineHandshakeMetadata)
logger.debug("Received init message: %s", init_message) logger.debug("Received init message: %s", init_message)
output_socket_address = init_message["output_socket_address"] received_parallel_config = init_message.parallel_config
#TBD(nick) maybe replace IP with configured head node address
received_parallel_config = init_message["parallel_config"]
for key, value in received_parallel_config.items(): for key, value in received_parallel_config.items():
setattr(parallel_config, key, value) setattr(parallel_config, key, value)
return output_socket_address return init_message.addresses
@staticmethod @staticmethod
def run_engine_core(*args, def run_engine_core(*args,
...@@ -528,7 +531,7 @@ class EngineCoreProc(EngineCore): ...@@ -528,7 +531,7 @@ class EngineCoreProc(EngineCore):
"""Exits when an engine step needs to be performed.""" """Exits when an engine step needs to be performed."""
waited = False waited = False
while not self.engines_running and not (self.scheduler.has_requests()): while not self.engines_running and not self.scheduler.has_requests():
if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.") logger.debug("EngineCore waiting for work.")
waited = True waited = True
...@@ -549,8 +552,8 @@ class EngineCoreProc(EngineCore): ...@@ -549,8 +552,8 @@ class EngineCoreProc(EngineCore):
# Step the engine core. # Step the engine core.
outputs, model_executed = self.step_fn() outputs, model_executed = self.step_fn()
# Put EngineCoreOutputs into the output queue. # Put EngineCoreOutputs into the output queue.
if outputs is not None: for output in (outputs.items() if outputs else ()):
self.output_queue.put_nowait(outputs) self.output_queue.put_nowait(output)
return model_executed return model_executed
...@@ -563,7 +566,7 @@ class EngineCoreProc(EngineCore): ...@@ -563,7 +566,7 @@ class EngineCoreProc(EngineCore):
elif request_type == EngineCoreRequestType.ABORT: elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request) self.abort_requests(request)
elif request_type == EngineCoreRequestType.UTILITY: elif request_type == EngineCoreRequestType.UTILITY:
call_id, method_name, args = request client_idx, call_id, method_name, args = request
output = UtilityOutput(call_id) output = UtilityOutput(call_id)
try: try:
method = getattr(self, method_name) method = getattr(self, method_name)
...@@ -574,7 +577,7 @@ class EngineCoreProc(EngineCore): ...@@ -574,7 +577,7 @@ class EngineCoreProc(EngineCore):
output.failure_message = (f"Call to {method_name} method" output.failure_message = (f"Call to {method_name} method"
f" failed: {str(e)}") f" failed: {str(e)}")
self.output_queue.put_nowait( self.output_queue.put_nowait(
EngineCoreOutputs(utility_output=output)) (client_idx, EngineCoreOutputs(utility_output=output)))
elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
raise RuntimeError("Executor failed.") raise RuntimeError("Executor failed.")
else: else:
...@@ -607,27 +610,68 @@ class EngineCoreProc(EngineCore): ...@@ -607,27 +610,68 @@ class EngineCoreProc(EngineCore):
logger.fatal("vLLM shutdown signal from EngineCore failed " logger.fatal("vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue.") "to send. Please report this issue.")
def process_input_socket(self, input_socket: zmq.Socket): def process_input_sockets(self, input_addresses: list[str],
coord_input_address: Optional[str],
identity: bytes):
"""Input socket IO thread.""" """Input socket IO thread."""
# Msgpack serialization decoding. # Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest) add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder() generic_decoder = MsgpackDecoder()
while True: with ExitStack() as stack, zmq.Context() as ctx:
# (RequestType, RequestData) input_sockets = [
type_frame, *data_frames = input_socket.recv_multipart(copy=False) stack.enter_context(
request_type = EngineCoreRequestType(bytes(type_frame.buffer)) make_zmq_socket(ctx,
input_address,
# Deserialize the request data. zmq.DEALER,
decoder = add_request_decoder if ( identity=identity,
request_type == EngineCoreRequestType.ADD) else generic_decoder bind=False))
request = decoder.decode(data_frames) for input_address in input_addresses
]
# Push to input queue for core busy loop. if coord_input_address is None:
self.input_queue.put_nowait((request_type, request)) coord_socket = None
else:
coord_socket = stack.enter_context(
make_zmq_socket(ctx,
coord_input_address,
zmq.XSUB,
identity=identity,
bind=False))
# Send subscription message to coordinator.
coord_socket.send(b'\x01')
# Register sockets with poller.
poller = zmq.Poller()
for input_socket in input_sockets:
# Send initial message to each input socket - this is required
# before the front-end ROUTER socket can send input messages
# back to us.
input_socket.send(b'')
poller.register(input_socket, zmq.POLLIN)
if coord_socket is not None:
poller.register(coord_socket, zmq.POLLIN)
def process_output_socket(self, output_path: str, engine_index: int): while True:
for input_socket, _ in poller.poll():
# (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(
copy=False)
request_type = EngineCoreRequestType(
bytes(type_frame.buffer))
# Deserialize the request data.
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frames)
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
def process_output_sockets(self, output_paths: list[str],
coord_output_path: Optional[str],
engine_index: int):
"""Output socket IO thread.""" """Output socket IO thread."""
# Msgpack serialization encoding. # Msgpack serialization encoding.
...@@ -641,30 +685,49 @@ class EngineCoreProc(EngineCore): ...@@ -641,30 +685,49 @@ class EngineCoreProc(EngineCore):
# We must set linger to ensure the ENGINE_CORE_DEAD # We must set linger to ensure the ENGINE_CORE_DEAD
# message is sent prior to closing the socket. # message is sent prior to closing the socket.
with zmq_socket_ctx(output_path, zmq.constants.PUSH, with ExitStack() as stack, zmq.Context() as ctx:
linger=4000) as socket: sockets = [
stack.enter_context(
make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
for output_path in output_paths
]
coord_socket = stack.enter_context(
make_zmq_socket(
ctx, coord_output_path, zmq.PUSH, bind=False,
linger=4000)) if coord_output_path is not None else None
max_reuse_bufs = len(sockets) + 1
while True: while True:
outputs = self.output_queue.get() output = self.output_queue.get()
if outputs == EngineCoreProc.ENGINE_CORE_DEAD: if output == EngineCoreProc.ENGINE_CORE_DEAD:
socket.send(outputs, copy=False) for socket in sockets:
socket.send(output)
break break
assert not isinstance(outputs, bytes) assert not isinstance(output, bytes)
client_index, outputs = output
outputs.engine_index = engine_index outputs.engine_index = engine_index
if client_index == -1:
# Don't reuse buffer for coordinator message
# which will be very small.
assert coord_socket is not None
coord_socket.send_multipart(encoder.encode(outputs))
continue
# Reclaim buffers that zmq is finished with. # Reclaim buffers that zmq is finished with.
while pending and pending[-1][0].done: while pending and pending[-1][0].done:
reuse_buffers.append(pending.pop()[2]) reuse_buffers.append(pending.pop()[2])
buffer = reuse_buffers.pop() if reuse_buffers else bytearray() buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
buffers = encoder.encode_into(outputs, buffer) buffers = encoder.encode_into(outputs, buffer)
tracker = socket.send_multipart(buffers, tracker = sockets[client_index].send_multipart(buffers,
copy=False, copy=False,
track=True) track=True)
if not tracker.done: if not tracker.done:
ref = outputs if len(buffers) > 1 else None ref = outputs if len(buffers) > 1 else None
pending.appendleft((tracker, ref, buffer)) pending.appendleft((tracker, ref, buffer))
elif len(reuse_buffers) < 2: elif len(reuse_buffers) < max_reuse_bufs:
# Keep at most 2 buffers to reuse. # Limit the number of buffers to reuse.
reuse_buffers.append(buffer) reuse_buffers.append(buffer)
...@@ -676,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -676,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc):
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
on_head_node: bool, on_head_node: bool,
input_address: str, handshake_address: str,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
): ):
...@@ -691,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -691,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc):
# Counts forward-passes of the model so that we can synchronize # Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps. # finished with DP peers every N steps.
self.counter = 0 self.counter = 0
self.current_wave = 0
# Initialize the engine. # Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(vllm_config, on_head_node, input_address, super().__init__(vllm_config, on_head_node, handshake_address,
executor_class, log_stats, dp_rank) executor_class, log_stats, dp_rank)
def _init_data_parallel(self, vllm_config: VllmConfig): def _init_data_parallel(self, vllm_config: VllmConfig):
...@@ -726,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -726,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc):
self.dp_rank = dp_rank self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
self.current_wave = 0
def shutdown(self): def shutdown(self):
super().shutdown() super().shutdown()
...@@ -734,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -734,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc):
stateless_destroy_torch_distributed_process_group(dp_group) stateless_destroy_torch_distributed_process_group(dp_group)
def add_request(self, request: EngineCoreRequest): def add_request(self, request: EngineCoreRequest):
if request.current_wave != self.current_wave: if self.has_coordinator and request.current_wave != self.current_wave:
if request.current_wave > self.current_wave: if request.current_wave > self.current_wave:
self.current_wave = request.current_wave self.current_wave = request.current_wave
elif not self.engines_running: elif not self.engines_running:
# Request received for an already-completed wave, notify # Request received for an already-completed wave, notify
# front-end that we need to start the next one. # front-end that we need to start the next one.
self.output_queue.put_nowait( self.output_queue.put_nowait(
EngineCoreOutputs(start_wave=self.current_wave)) (-1, EngineCoreOutputs(start_wave=self.current_wave)))
super().add_request(request) super().add_request(request)
def _handle_client_request(self, request_type: EngineCoreRequestType, def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None: request: Any) -> None:
if request_type == EngineCoreRequestType.START_DP_WAVE: if request_type == EngineCoreRequestType.START_DP_WAVE:
new_wave: int = request new_wave, exclude_eng_index = request
if new_wave >= self.current_wave: if exclude_eng_index != self.engine_index and (
new_wave >= self.current_wave):
self.current_wave = new_wave self.current_wave = new_wave
if not self.engines_running: if not self.engines_running:
logger.debug("EngineCore starting idle loop for wave %d.", logger.debug("EngineCore starting idle loop for wave %d.",
...@@ -758,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -758,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc):
else: else:
super()._handle_client_request(request_type, request) super()._handle_client_request(request_type, request)
def _maybe_publish_request_counts(self):
if not self.has_coordinator:
return
# Publish our request counts (if they've changed).
counts = self.scheduler.get_request_counts()
if counts != self.last_counts:
self.last_counts = counts
stats = SchedulerStats(*counts)
self.output_queue.put_nowait(
(-1, EngineCoreOutputs(scheduler_stats=stats)))
def run_busy_loop(self): def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case.""" """Core busy loop of the EngineCore for data parallel case."""
...@@ -768,6 +844,8 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -768,6 +844,8 @@ class DPEngineCoreProc(EngineCoreProc):
# 2) Step the engine core. # 2) Step the engine core.
executed = self._process_engine_step() executed = self._process_engine_step()
self._maybe_publish_request_counts()
local_unfinished_reqs = self.scheduler.has_unfinished_requests() local_unfinished_reqs = self.scheduler.has_unfinished_requests()
if not executed: if not executed:
if not local_unfinished_reqs and not self.engines_running: if not local_unfinished_reqs and not self.engines_running:
...@@ -788,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -788,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc):
logger.debug("Wave %d finished, pausing engine loop.", logger.debug("Wave %d finished, pausing engine loop.",
self.current_wave) self.current_wave)
self.output_queue.put_nowait( self.output_queue.put_nowait(
EngineCoreOutputs(wave_complete=self.current_wave)) (-1,
EngineCoreOutputs(wave_complete=self.current_wave)))
self.current_wave += 1 self.current_wave += 1
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import asyncio import asyncio
import contextlib import contextlib
import queue import queue
import sys
import uuid import uuid
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -9,26 +10,28 @@ from collections import deque ...@@ -9,26 +10,28 @@ from collections import deque
from collections.abc import Awaitable, Sequence from collections.abc import Awaitable, Sequence
from concurrent.futures import Future from concurrent.futures import Future
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto
from threading import Thread from threading import Thread
from typing import Any, Callable, Optional, TypeVar, Union from typing import Any, Callable, Optional, TypeVar, Union
import msgspec import msgspec.msgpack
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from vllm.config import ParallelConfig, VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_port, get_open_zmq_inproc_path, from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket,
get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket) zmq_socket_ctx)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput) EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import CoreEngineProcManager from vllm.v1.utils import (CoreEngine, CoreEngineProcManager,
EngineZmqAddresses, get_engine_client_zmq_addr,
wait_for_engine_startup)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -36,8 +39,6 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]] ...@@ -36,8 +39,6 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc _R = TypeVar('_R') # Return type for collective_rpc
STARTUP_POLL_PERIOD_MS = 10000
class EngineCoreClient(ABC): class EngineCoreClient(ABC):
""" """
...@@ -207,7 +208,7 @@ class InprocClient(EngineCoreClient): ...@@ -207,7 +208,7 @@ class InprocClient(EngineCoreClient):
def get_output(self) -> EngineCoreOutputs: def get_output(self) -> EngineCoreOutputs:
outputs, _ = self.engine_core.step() outputs, _ = self.engine_core.step()
return outputs return outputs.get(0) or EngineCoreOutputs()
def add_request(self, request: EngineCoreRequest) -> None: def add_request(self, request: EngineCoreRequest) -> None:
self.engine_core.add_request(request) self.engine_core.add_request(request)
...@@ -266,24 +267,6 @@ class InprocClient(EngineCoreClient): ...@@ -266,24 +267,6 @@ class InprocClient(EngineCoreClient):
return self.engine_core.collective_rpc(method, timeout, args, kwargs) return self.engine_core.collective_rpc(method, timeout, args, kwargs)
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank."""
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(length=2, byteorder="little")
self.state = CoreEngineState.NEW
self.num_reqs_in_flight = 0
@dataclass @dataclass
class BackgroundResources: class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding """Used as a finalizer for clean shutdown, avoiding
...@@ -291,9 +274,12 @@ class BackgroundResources: ...@@ -291,9 +274,12 @@ class BackgroundResources:
ctx: Union[zmq.Context] ctx: Union[zmq.Context]
local_engine_manager: Optional[CoreEngineProcManager] = None local_engine_manager: Optional[CoreEngineProcManager] = None
coordinator: Optional[DPCoordinator] = None
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
first_req_send_socket: Optional[zmq.asyncio.Socket] = None
output_queue_task: Optional[asyncio.Task] = None output_queue_task: Optional[asyncio.Task] = None
stats_update_task: Optional[asyncio.Task] = None
shutdown_path: Optional[str] = None shutdown_path: Optional[str] = None
# Set if any of the engines are dead. Here so that the output # Set if any of the engines are dead. Here so that the output
...@@ -306,16 +292,21 @@ class BackgroundResources: ...@@ -306,16 +292,21 @@ class BackgroundResources:
self.engine_dead = True self.engine_dead = True
if self.local_engine_manager is not None: if self.local_engine_manager is not None:
self.local_engine_manager.close() self.local_engine_manager.close()
if self.coordinator is not None:
self.coordinator.close()
if self.output_queue_task is not None: if self.output_queue_task is not None:
self.output_queue_task.cancel() self.output_queue_task.cancel()
if self.stats_update_task is not None:
self.stats_update_task.cancel()
# ZMQ context termination can hang if the sockets # ZMQ context termination can hang if the sockets
# aren't explicitly closed first. # aren't explicitly closed first.
if self.output_socket is not None: for socket in (self.output_socket, self.input_socket,
self.output_socket.close(linger=0) self.first_req_send_socket):
if self.input_socket is not None: if socket is not None:
self.input_socket.close(linger=0) socket.close(linger=0)
if self.shutdown_path is not None: if self.shutdown_path is not None:
# We must ensure that the sync output socket is # We must ensure that the sync output socket is
# closed cleanly in its own thread. # closed cleanly in its own thread.
...@@ -350,6 +341,7 @@ class MPClient(EngineCoreClient): ...@@ -350,6 +341,7 @@ class MPClient(EngineCoreClient):
vllm_config: VllmConfig, vllm_config: VllmConfig,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
# Serialization setup. # Serialization setup.
...@@ -369,8 +361,8 @@ class MPClient(EngineCoreClient): ...@@ -369,8 +361,8 @@ class MPClient(EngineCoreClient):
try: try:
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local local_engine_count = parallel_config.data_parallel_size_local
start_index = parallel_config.data_parallel_rank
local_start_index = parallel_config.data_parallel_rank_local local_start_index = parallel_config.data_parallel_rank_local
dp_size = parallel_config.data_parallel_size
# SPMD mode is where there is an LLM instance per DP rank and # SPMD mode is where there is an LLM instance per DP rank and
# one core engine per LLM, see # one core engine per LLM, see
...@@ -382,42 +374,53 @@ class MPClient(EngineCoreClient): ...@@ -382,42 +374,53 @@ class MPClient(EngineCoreClient):
CoreEngine(index=local_start_index, local=True) CoreEngine(index=local_start_index, local=True)
] ]
else: else:
assert start_index == 0 assert parallel_config.data_parallel_rank == 0
local_start_index = 0 local_start_index = 0
self.core_engines = [ self.core_engines = [
CoreEngine(index=i, local=(i < local_engine_count)) CoreEngine(index=i, local=(i < local_engine_count))
for i in range(parallel_config.data_parallel_size) for i in range(dp_size)
] ]
input_address, output_address = self._get_zmq_addresses( local_only = spmd_mode or local_engine_count == dp_size
parallel_config, spmd_mode)
self.stats_update_address: Optional[str] = None
if client_addresses is not None:
input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get(
"stats_update_address")
else:
host = parallel_config.data_parallel_master_ip
input_address = get_engine_client_zmq_addr(local_only, host)
output_address = get_engine_client_zmq_addr(local_only, host)
# Create input and output sockets. # Create input and output sockets.
self.input_socket = self.resources.input_socket = make_zmq_socket( self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True) self.ctx, input_address, zmq.ROUTER, bind=True)
self.resources.output_socket = make_zmq_socket( self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.constants.PULL) self.ctx, output_address, zmq.PULL)
# Start local engines.
if local_engine_count: if client_addresses is None:
# In server mode, start_index and local_start_index will self._init_engines_direct(vllm_config, local_only,
# both be 0. local_start_index, input_address,
self.resources.local_engine_manager = CoreEngineProcManager( output_address, executor_class,
EngineCoreProc.run_engine_core, log_stats)
vllm_config=vllm_config, coordinator = self.resources.coordinator
executor_class=executor_class, if coordinator:
log_stats=log_stats, self.stats_update_address = (
input_address=input_address, coordinator.get_stats_publish_address())
on_head_node=True,
local_engine_count=local_engine_count, # Wait for ready messages from each engine on the input socket.
start_index=start_index, identities = set(e.identity for e in self.core_engines)
local_start_index=local_start_index) sync_input_socket = zmq.Socket.shadow(self.input_socket)
while identities:
if not sync_input_socket.poll(timeout=600_000):
raise TimeoutError("Timed out waiting for engines to send"
"initial message on input socket.")
identity, _ = sync_input_socket.recv_multipart()
identities.remove(identity)
self.core_engine = self.core_engines[0] self.core_engine = self.core_engines[0]
# Wait for engine core process(es) to start.
self._wait_for_engine_startup(output_address, parallel_config)
self.utility_results: dict[int, AnyFuture] = {} self.utility_results: dict[int, AnyFuture] = {}
# Request objects which may contain pytorch-allocated tensors # Request objects which may contain pytorch-allocated tensors
...@@ -430,116 +433,67 @@ class MPClient(EngineCoreClient): ...@@ -430,116 +433,67 @@ class MPClient(EngineCoreClient):
if not success: if not success:
self._finalizer() self._finalizer()
@staticmethod def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
def _get_zmq_addresses(parallel_config: ParallelConfig, local_start_index: int, input_address: str,
spmd_mode: bool) -> tuple[str, str]: output_address: str,
"""Returns (input_address, output_address).""" executor_class: type[Executor], log_stats: bool):
dp_size = parallel_config.data_parallel_size """Self-contained client mode, launch engine and coordinator process
as needed."""
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local local_engine_count = parallel_config.data_parallel_size_local
start_index = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
if local_engine_count == dp_size or spmd_mode: if len(self.core_engines) > 1:
input_address = get_open_zmq_ipc_path() self.resources.coordinator = DPCoordinator(parallel_config)
output_address = get_open_zmq_ipc_path()
else: handshake_address = get_engine_client_zmq_addr(
host = parallel_config.data_parallel_master_ip local_only, host, parallel_config.data_parallel_rpc_port)
input_port = parallel_config.data_parallel_rpc_port
output_port = get_open_port()
input_address = get_tcp_uri(host, input_port)
output_address = get_tcp_uri(host, output_port)
return input_address, output_address
def _wait_for_engine_startup(self, output_address: str,
parallel_config: ParallelConfig):
# Get a sync handle to the socket which can be sync or async.
sync_input_socket = zmq.Socket.shadow(self.input_socket)
# Wait for engine core process(es) to send ready messages.
local_count = parallel_config.data_parallel_size_local
remote_count = len(self.core_engines) - local_count
# [local, remote] counts
conn_pending, start_pending = [local_count, remote_count], [0, 0]
poller = zmq.Poller()
poller.register(sync_input_socket, zmq.POLLIN)
proc_manager = self.resources.local_engine_manager
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
while any(conn_pending) or any(start_pending):
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
if any(conn_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect.", *conn_pending)
if any(start_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to start.", *start_pending)
continue
if len(events) > 1 or events[0][0] != sync_input_socket:
# One of the local core processes exited.
finished = proc_manager.finished_procs(
) if proc_manager else {}
raise RuntimeError("Engine core initialization failed. "
"See root cause above. "
f"Failed core proc(s): {finished}")
# Receive HELLO and READY messages from the input socket.
eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart()
eng_index = int.from_bytes(eng_identity, byteorder="little")
engine = next(
(e for e in self.core_engines if e.identity == eng_identity),
None)
if engine is None:
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info.
init_message = self.encoder.encode({
"output_socket_address": output_address,
"parallel_config": {
"data_parallel_master_ip":
parallel_config.data_parallel_master_ip,
"data_parallel_master_port":
parallel_config.data_parallel_master_port,
"data_parallel_size":
parallel_config.data_parallel_size,
},
})
sync_input_socket.send_multipart((eng_identity, *init_message),
copy=False)
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif status == "READY" and (engine.state
== CoreEngineState.CONNECTED):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
cache_config = self.vllm_config.cache_config
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg['num_gpu_blocks']
cache_config.num_gpu_blocks = num_gpu_blocks
start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
else:
raise RuntimeError(f"Unexpected {status} message for "
f"{'local' if local else 'remote'} engine "
f"{eng_index} in {engine.state} state.")
logger.debug("%s from %s core engine process %s.", status, with zmq_socket_ctx(handshake_address, zmq.ROUTER,
"local" if local else "remote", eng_index) bind=True) as handshake_socket:
# Start local engines.
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
self.resources.local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
handshake_address=handshake_address,
on_head_node=True,
local_engine_count=local_engine_count,
start_index=start_index,
local_start_index=local_start_index)
# Wait for engine core process(es) to start.
self._wait_for_engine_startup(handshake_socket, input_address,
output_address)
def _wait_for_engine_startup(self, handshake_socket: zmq.Socket,
input_address: str, output_address: str):
addresses = EngineZmqAddresses(
inputs=[input_address],
outputs=[output_address],
)
coordinator = self.resources.coordinator
if coordinator is not None:
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
wait_for_engine_startup(
handshake_socket,
addresses,
self.core_engines,
self.vllm_config.parallel_config,
self.vllm_config.cache_config,
self.resources.local_engine_manager,
coordinator.proc if coordinator else None,
)
def shutdown(self): def shutdown(self):
# Terminate background resources. # Terminate background resources.
...@@ -605,8 +559,8 @@ class SyncMPClient(MPClient): ...@@ -605,8 +559,8 @@ class SyncMPClient(MPClient):
try: try:
shutdown_socket.bind(shutdown_path) shutdown_socket.bind(shutdown_path)
poller = zmq.Poller() poller = zmq.Poller()
poller.register(shutdown_socket) poller.register(shutdown_socket, zmq.POLLIN)
poller.register(out_socket) poller.register(out_socket, zmq.POLLIN)
while True: while True:
socks = poller.poll() socks = poller.poll()
if not socks: if not socks:
...@@ -668,7 +622,7 @@ class SyncMPClient(MPClient): ...@@ -668,7 +622,7 @@ class SyncMPClient(MPClient):
future: Future[Any] = Future() future: Future[Any] = Future()
self.utility_results[call_id] = future self.utility_results[call_id] = future
self._send_input(EngineCoreRequestType.UTILITY, self._send_input(EngineCoreRequestType.UTILITY,
(call_id, method, args)) (0, call_id, method, args))
return future.result() return future.result()
...@@ -730,15 +684,21 @@ class SyncMPClient(MPClient): ...@@ -730,15 +684,21 @@ class SyncMPClient(MPClient):
class AsyncMPClient(MPClient): class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore.""" """Asyncio-compatible client for multi-proc EngineCore."""
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], def __init__(self,
log_stats: bool): vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0):
super().__init__( super().__init__(
asyncio_mode=True, asyncio_mode=True,
vllm_config=vllm_config, vllm_config=vllm_config,
executor_class=executor_class, executor_class=executor_class,
log_stats=log_stats, log_stats=log_stats,
client_addresses=client_addresses,
) )
self.client_index = client_index
self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
Exception]]() Exception]]()
try: try:
...@@ -854,12 +814,13 @@ class AsyncMPClient(MPClient): ...@@ -854,12 +814,13 @@ class AsyncMPClient(MPClient):
future = asyncio.get_running_loop().create_future() future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future self.utility_results[call_id] = future
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
(call_id, method, args))) (self.client_index, call_id, method, args)))
await self._send_input_message(message, engine, args) await self._send_input_message(message, engine, args)
self._ensure_output_queue_task() self._ensure_output_queue_task()
return await future return await future
async def add_request_async(self, request: EngineCoreRequest) -> None: async def add_request_async(self, request: EngineCoreRequest) -> None:
request.client_index = self.client_index
await self._send_input(EngineCoreRequestType.ADD, request) await self._send_input(EngineCoreRequestType.ADD, request)
self._ensure_output_queue_task() self._ensure_output_queue_task()
...@@ -921,17 +882,120 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -921,17 +882,120 @@ class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel) """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore.""" EngineCore."""
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], def __init__(self,
log_stats: bool): vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0):
self.current_wave = 0 self.current_wave = 0
self.engines_running = False self.engines_running = False
# To route aborts to the correct engine.
self.reqs_in_flight: dict[str, CoreEngine] = {} self.reqs_in_flight: dict[str, CoreEngine] = {}
super().__init__(vllm_config, executor_class, log_stats) super().__init__(vllm_config, executor_class, log_stats,
client_addresses, client_index)
assert len(self.core_engines) > 1 assert len(self.core_engines) > 1
# List of [waiting, running] pair per engine.
self.lb_engines: list[list[int]] = []
self.first_req_sock_addr = get_open_zmq_inproc_path()
self.first_req_send_socket = self.resources.first_req_send_socket = (
make_zmq_socket(self.ctx,
self.first_req_sock_addr,
zmq.PAIR,
bind=True))
try:
# If we are running in an asyncio event loop, start the stats task.
# Otherwise, it will be started lazily.
asyncio.get_running_loop()
self._ensure_stats_update_task()
except RuntimeError:
pass
def _ensure_stats_update_task(self):
resources = self.resources
if resources.stats_update_task is not None:
return
assert self.stats_update_address is not None
async def run_engine_stats_update_task():
with make_zmq_socket(self.ctx, self.stats_update_address,
zmq.XSUB) as socket, make_zmq_socket(
self.ctx,
self.first_req_sock_addr,
zmq.PAIR,
bind=False) as first_req_rcv_socket:
# Send subscription message.
await socket.send(b'\x01')
poller = zmq.asyncio.Poller()
poller.register(socket, zmq.POLLIN)
poller.register(first_req_rcv_socket, zmq.POLLIN)
while True:
events = await poller.poll()
if not self.engines_running and len(events) == 2 or (
events[0][0] == first_req_rcv_socket):
# Send a message to notify the coordinator that
# we're sending a request while the engines are
# paused, so that it can wake the others up
# (to run dummy EP loop).
self.engines_running = True
buf = first_req_rcv_socket.recv(
flags=zmq.NOBLOCK).result()
target_eng_index = int.from_bytes(buf, "little")
msg = msgspec.msgpack.encode(
(target_eng_index, self.current_wave))
await socket.send(msg)
buf = None
while True:
# Drain all stats events (we only care about latest).
future: asyncio.Future[bytes] = socket.recv(
flags=zmq.NOBLOCK)
if isinstance(future.exception(), zmq.Again):
break
buf = future.result()
if buf is None:
continue
# Update local load-balancing state.
counts, wave, running = msgspec.msgpack.decode(buf)
self.current_wave = wave
self.engines_running = running
self.lb_engines = counts
resources.stats_update_task = asyncio.create_task(
run_engine_stats_update_task())
def get_core_engine_for_request(self) -> CoreEngine:
if not self.lb_engines:
return self.core_engines[0]
# TODO use P2C alg for larger DP sizes
num_engines = len(self.lb_engines)
min_counts = [sys.maxsize, sys.maxsize]
eng_index = 0
for i in range(num_engines):
# Start from client_index to help with balancing when engines
# are empty.
idx = (self.client_index + i) % num_engines
counts = self.lb_engines[idx]
if counts < min_counts:
min_counts = counts
eng_index = idx
# Adjust local counts for better balancing between stats updates
# from the coordinator (which happen every 100ms).
if min_counts[0]:
min_counts[0] += 1
else:
min_counts[1] += 1
return self.core_engines[eng_index]
async def call_utility_async(self, method: str, *args) -> Any: async def call_utility_async(self, method: str, *args) -> Any:
# Only the result from the first engine is returned. # Only the result from the first engine is returned.
return (await asyncio.gather(*[ return (await asyncio.gather(*[
...@@ -940,62 +1004,30 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -940,62 +1004,30 @@ class DPAsyncMPClient(AsyncMPClient):
]))[0] ]))[0]
async def add_request_async(self, request: EngineCoreRequest) -> None: async def add_request_async(self, request: EngineCoreRequest) -> None:
self._ensure_stats_update_task()
request.current_wave = self.current_wave request.current_wave = self.current_wave
request.client_index = self.client_index
chosen_engine = self.get_core_engine_for_request() chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine self.reqs_in_flight[request.request_id] = chosen_engine
chosen_engine.num_reqs_in_flight += 1
to_await = self._send_input(EngineCoreRequestType.ADD, request, to_await = self._send_input(EngineCoreRequestType.ADD, request,
chosen_engine) chosen_engine)
if not self.engines_running: if not self.engines_running:
# Send request to chosen engine and dp start loop # Notify coordinator that we're sending a request
# control message to all other engines. await self.first_req_send_socket.send(chosen_engine.identity)
self.engines_running = True
to_await = asyncio.gather(
to_await, # type: ignore[assignment]
*self._start_wave_coros(exclude_index=chosen_engine.index))
await to_await await to_await
self._ensure_output_queue_task() self._ensure_output_queue_task()
def get_core_engine_for_request(self) -> CoreEngine:
return min(self.core_engines, key=lambda e: e.num_reqs_in_flight)
@staticmethod @staticmethod
async def process_engine_outputs(self: "DPAsyncMPClient", async def process_engine_outputs(self: "DPAsyncMPClient",
outputs: EngineCoreOutputs): outputs: EngineCoreOutputs):
if self.reqs_in_flight: if outputs.finished_requests and self.reqs_in_flight:
for req_id in outputs.finished_requests or (): for req_id in outputs.finished_requests:
if engine := self.reqs_in_flight.pop(req_id, None): self.reqs_in_flight.pop(req_id, None)
engine.num_reqs_in_flight -= 1
if outputs.wave_complete is not None:
# Current wave is complete, move to next wave number
# and mark engines as paused.
if self.current_wave <= outputs.wave_complete:
self.current_wave = outputs.wave_complete + 1
self.engines_running = False
elif outputs.start_wave is not None and (
outputs.start_wave > self.current_wave or
(outputs.start_wave == self.current_wave
and not self.engines_running)):
# Engine received request for a non-current wave so we must ensure
# that other engines progress to the next wave.
self.current_wave = outputs.start_wave
self.engines_running = True
await asyncio.gather(*self._start_wave_coros(
exclude_index=outputs.engine_index))
def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]:
logger.debug("Sending start DP wave %d.", self.current_wave)
return [
self._send_input(EngineCoreRequestType.START_DP_WAVE,
self.current_wave, engine)
for engine in self.core_engines if engine.index != exclude_index
]
async def abort_requests_async(self, request_ids: list[str]) -> None: async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids: if not request_ids:
......
...@@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig ...@@ -12,13 +12,12 @@ from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__) logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5.0
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
...@@ -35,7 +34,7 @@ class StatLoggerBase(ABC): ...@@ -35,7 +34,7 @@ class StatLoggerBase(ABC):
... ...
@abstractmethod @abstractmethod
def record(self, scheduler_stats: SchedulerStats, def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]): iteration_stats: Optional[IterationStats]):
... ...
...@@ -78,20 +77,22 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -78,20 +77,22 @@ class LoggingStatLogger(StatLoggerBase):
# Compute summary metrics for tracked stats # Compute summary metrics for tracked stats
return float(np.sum(tracked_stats) / (now - self.last_log_time)) return float(np.sum(tracked_stats) / (now - self.last_log_time))
def record(self, scheduler_stats: SchedulerStats, def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]): iteration_stats: Optional[IterationStats]):
"""Log Stats to standard output.""" """Log Stats to standard output."""
if iteration_stats: if iteration_stats:
self._track_iteration_stats(iteration_stats) self._track_iteration_stats(iteration_stats)
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) if scheduler_stats is not None:
self.prefix_caching_metrics.observe(
scheduler_stats.prefix_cache_stats)
if scheduler_stats.spec_decoding_stats is not None: if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_logging.observe( self.spec_decoding_logging.observe(
scheduler_stats.spec_decoding_stats) scheduler_stats.spec_decoding_stats)
self.last_scheduler_stats = scheduler_stats self.last_scheduler_stats = scheduler_stats
def log(self): def log(self):
now = time.monotonic() now = time.monotonic()
...@@ -131,10 +132,11 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -131,10 +132,11 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging.log(log_fn=log_fn) self.spec_decoding_logging.log(log_fn=log_fn)
def log_engine_initialized(self): def log_engine_initialized(self):
logger.info( if self.vllm_config.cache_config.num_gpu_blocks:
"vllm cache_config_info with initialization " \ logger.info(
"after num_gpu_blocks is: %d", "Engine %03d: vllm cache_config_info with initialization "
self.vllm_config.cache_config.num_gpu_blocks) "after num_gpu_blocks is: %d", self.engine_index,
self.vllm_config.cache_config.num_gpu_blocks)
class PrometheusStatLogger(StatLoggerBase): class PrometheusStatLogger(StatLoggerBase):
...@@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -144,7 +146,8 @@ class PrometheusStatLogger(StatLoggerBase):
_spec_decoding_cls = SpecDecodingProm _spec_decoding_cls = SpecDecodingProm
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self._unregister_vllm_metrics()
unregister_vllm_metrics()
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.engine_index = engine_index self.engine_index = engine_index
# Use this flag to hide metrics that were deprecated in # Use this flag to hide metrics that were deprecated in
...@@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -169,11 +172,13 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_scheduler_running = self._gauge_cls( self.gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running", name="vllm:num_requests_running",
documentation="Number of requests in model execution batches.", documentation="Number of requests in model execution batches.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
self.gauge_scheduler_waiting = self._gauge_cls( self.gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting", name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.", documentation="Number of requests waiting to be processed.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
# #
...@@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -182,6 +187,7 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_gpu_cache_usage = self._gauge_cls( self.gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc", name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.", documentation="GPU KV-cache usage. 1 means 100 percent usage.",
multiprocess_mode="mostrecent",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
self.counter_gpu_prefix_cache_queries = self._counter_cls( self.counter_gpu_prefix_cache_queries = self._counter_cls(
...@@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -242,6 +248,9 @@ class PrometheusStatLogger(StatLoggerBase):
buckets=build_1_2_5_buckets(max_model_len), buckets=build_1_2_5_buckets(max_model_len),
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
# See: https://github.com/vllm-project/vllm/pull/18053
self.histogram_iteration_tokens = \ self.histogram_iteration_tokens = \
self._histogram_cls( self._histogram_cls(
name="vllm:iteration_tokens_total", name="vllm:iteration_tokens_total",
...@@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -340,6 +349,9 @@ class PrometheusStatLogger(StatLoggerBase):
# #
# LoRA metrics # LoRA metrics
# #
# TODO: This metric might be incorrect in case of using multiple
# api_server counts which uses prometheus mp.
self.gauge_lora_info: Optional[prometheus_client.Gauge] = None self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
if vllm_config.lora_config is not None: if vllm_config.lora_config is not None:
self.labelname_max_lora = "max_lora" self.labelname_max_lora = "max_lora"
...@@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -350,13 +362,16 @@ class PrometheusStatLogger(StatLoggerBase):
self._gauge_cls( self._gauge_cls(
name="vllm:lora_requests_info", name="vllm:lora_requests_info",
documentation="Running stats on lora requests.", documentation="Running stats on lora requests.",
multiprocess_mode="sum",
labelnames=[ labelnames=[
self.labelname_max_lora, self.labelname_max_lora,
self.labelname_waiting_lora_adapters, self.labelname_waiting_lora_adapters,
self.labelname_running_lora_adapters, self.labelname_running_lora_adapters,
]) ],
)
def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
metrics_info = config_obj.metrics_info() metrics_info = config_obj.metrics_info()
metrics_info["engine"] = self.engine_index metrics_info["engine"] = self.engine_index
...@@ -372,25 +387,28 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -372,25 +387,28 @@ class PrometheusStatLogger(StatLoggerBase):
info_gauge = self._gauge_cls( info_gauge = self._gauge_cls(
name=name, name=name,
documentation=documentation, documentation=documentation,
labelnames=metrics_info.keys()).labels(**metrics_info) multiprocess_mode="mostrecent",
labelnames=metrics_info.keys(),
).labels(**metrics_info)
info_gauge.set(1) info_gauge.set(1)
def record(self, scheduler_stats: SchedulerStats, def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]): iteration_stats: Optional[IterationStats]):
"""Log to prometheus.""" """Log to prometheus."""
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) if scheduler_stats is not None:
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage) self.gauge_gpu_cache_usage.set(scheduler_stats.gpu_cache_usage)
self.counter_gpu_prefix_cache_queries.inc( self.counter_gpu_prefix_cache_queries.inc(
scheduler_stats.prefix_cache_stats.queries) scheduler_stats.prefix_cache_stats.queries)
self.counter_gpu_prefix_cache_hits.inc( self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits) scheduler_stats.prefix_cache_stats.hits)
if scheduler_stats.spec_decoding_stats is not None: if scheduler_stats.spec_decoding_stats is not None:
self.spec_decoding_prom.observe( self.spec_decoding_prom.observe(
scheduler_stats.spec_decoding_stats) scheduler_stats.spec_decoding_stats)
if iteration_stats is None: if iteration_stats is None:
return return
...@@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -445,13 +463,6 @@ class PrometheusStatLogger(StatLoggerBase):
self.gauge_lora_info.labels(**lora_info_labels)\ self.gauge_lora_info.labels(**lora_info_labels)\
.set_to_current_time() .set_to_current_time()
@staticmethod
def _unregister_vllm_metrics():
# Unregister any existing vLLM collectors (for CI/CD
for collector in list(prometheus_client.REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
prometheus_client.REGISTRY.unregister(collector)
def log_engine_initialized(self): def log_engine_initialized(self):
self.log_metrics_info("cache_config", self.vllm_config.cache_config) self.log_metrics_info("cache_config", self.vllm_config.cache_config)
......
# SPDX-License-Identifier: Apache-2.0
import os
import tempfile
from typing import Optional
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
from vllm.logger import init_logger
logger = init_logger(__name__)
# Global temporary directory for prometheus multiprocessing
_prometheus_multiproc_dir: Optional[tempfile.TemporaryDirectory] = None
def setup_multiprocess_prometheus():
"""Set up prometheus multiprocessing directory if not already configured.
"""
global _prometheus_multiproc_dir
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
_prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name
logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s",
_prometheus_multiproc_dir.name)
else:
logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")
def get_prometheus_registry():
"""Get the appropriate prometheus registry based on multiprocessing
configuration.
Returns:
Registry: A prometheus registry
"""
if os.getenv("PROMETHEUS_MULTIPROC_DIR") is not None:
logger.debug("Using multiprocess registry for prometheus metrics")
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
return registry
return REGISTRY
def unregister_vllm_metrics():
"""Unregister any existing vLLM collectors from the prometheus registry.
This is useful for testing and CI/CD where metrics may be registered
multiple times across test runs.
Also, in case of multiprocess, we need to unregister the metrics from the
global registry.
"""
registry = REGISTRY
# Unregister any existing vLLM collectors
for collector in list(registry._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
registry.unregister(collector)
def shutdown_prometheus():
"""Shutdown prometheus metrics."""
try:
pid = os.getpid()
multiprocess.mark_process_dead(pid)
logger.debug("Marked Prometheus metrics for process %d as dead", pid)
except Exception as e:
logger.error("Error during metrics cleanup: %s", str(e))
...@@ -26,12 +26,13 @@ class Request: ...@@ -26,12 +26,13 @@ class Request:
multi_modal_placeholders: Optional[list[PlaceholderRange]], multi_modal_placeholders: Optional[list[PlaceholderRange]],
sampling_params: SamplingParams, sampling_params: SamplingParams,
eos_token_id: Optional[int], eos_token_id: Optional[int],
arrival_time: float, client_index: int = 0,
lora_request: Optional["LoRARequest"] = None, lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.client_index = client_index
self.sampling_params = sampling_params self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request. # Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
...@@ -90,13 +91,13 @@ class Request: ...@@ -90,13 +91,13 @@ class Request:
return cls( return cls(
request_id=request.request_id, request_id=request.request_id,
client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
multi_modal_inputs=request.mm_inputs, multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes, multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders, multi_modal_placeholders=request.mm_placeholders,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
eos_token_id=request.eos_token_id, eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request, lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest( structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params), sampling_params=request.sampling_params),
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import argparse
import multiprocessing
import time import time
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection from multiprocessing import Process, connection
from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, from multiprocessing.process import BaseProcess
overload) from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
import msgspec
import torch import torch
import zmq
from vllm.config import VllmConfig from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import get_mp_context, kill_process_tree from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
get_tcp_uri, kill_process_tree)
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator
logger = init_logger(__name__) logger = init_logger(__name__)
T = TypeVar("T") T = TypeVar("T")
STARTUP_POLL_PERIOD_MS = 10000
class ConstantList(Generic[T], Sequence): class ConstantList(Generic[T], Sequence):
...@@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence): ...@@ -95,6 +105,78 @@ class ConstantList(Generic[T], Sequence):
return f"ConstantList({self._x})" return f"ConstantList({self._x})"
def get_engine_client_zmq_addr(local_only: bool,
host: str,
port: int = 0) -> str:
return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
host, port or get_open_port()))
class APIServerProcessManager:
"""Manages a group of API server processes.
Handles creation, monitoring, and termination of API server worker
processes. Also monitors extra processes to check if they are healthy.
"""
def __init__(
self,
target_server_fn: Callable,
listen_address: str,
sock: Any,
args: argparse.Namespace,
num_servers: int,
input_addresses: list[str],
output_addresses: list[str],
stats_update_address: Optional[str] = None,
):
"""Initialize and start API server worker processes.
Args:
target_server_fn: Function to call for each API server process
listen_address: Address to listen for client connections
sock: Socket for client connections
args: Command line arguments
num_servers: Number of API server processes to start
input_addresses: Input addresses for each API server
output_addresses: Output addresses for each API server
stats_update_address: Optional stats update address
"""
self.listen_address = listen_address
self.sock = sock
self.args = args
# Start API servers
spawn_context = multiprocessing.get_context("spawn")
self.processes: list[BaseProcess] = []
for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
output_addresses):
client_config = {
"input_address": in_addr,
"output_address": out_addr,
"client_index": i
}
if stats_update_address is not None:
client_config["stats_update_address"] = stats_update_address
proc = spawn_context.Process(target=target_server_fn,
name=f"ApiServer_{i}",
args=(listen_address, sock, args,
client_config))
self.processes.append(proc)
proc.start()
logger.info("Started %d API server processes", len(self.processes))
# Shutdown only the API server processes on garbage collection
# The extra processes are managed by their owners
self._finalizer = weakref.finalize(self, shutdown, self.processes)
def close(self) -> None:
self._finalizer()
class CoreEngineProcManager: class CoreEngineProcManager:
""" """
Utility class to handle creation, readiness, and shutdown Utility class to handle creation, readiness, and shutdown
...@@ -109,7 +191,7 @@ class CoreEngineProcManager: ...@@ -109,7 +191,7 @@ class CoreEngineProcManager:
local_start_index: int, local_start_index: int,
vllm_config: VllmConfig, vllm_config: VllmConfig,
on_head_node: bool, on_head_node: bool,
input_address: str, handshake_address: str,
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
): ):
...@@ -117,12 +199,12 @@ class CoreEngineProcManager: ...@@ -117,12 +199,12 @@ class CoreEngineProcManager:
common_kwargs = { common_kwargs = {
"vllm_config": vllm_config, "vllm_config": vllm_config,
"on_head_node": on_head_node, "on_head_node": on_head_node,
"input_address": input_address, "handshake_address": handshake_address,
"executor_class": executor_class, "executor_class": executor_class,
"log_stats": log_stats, "log_stats": log_stats,
} }
self.processes: list[Process] = [] self.processes: list[BaseProcess] = []
for index in range(local_engine_count): for index in range(local_engine_count):
local_index = local_start_index + index local_index = local_start_index + index
global_index = start_index + index global_index = start_index + index
...@@ -135,8 +217,7 @@ class CoreEngineProcManager: ...@@ -135,8 +217,7 @@ class CoreEngineProcManager:
"local_dp_rank": local_index, "local_dp_rank": local_index,
})) }))
self._finalizer = weakref.finalize(self, shutdown, self.processes, self._finalizer = weakref.finalize(self, shutdown, self.processes)
input_address)
try: try:
for proc in self.processes: for proc in self.processes:
proc.start() proc.start()
...@@ -164,9 +245,199 @@ class CoreEngineProcManager: ...@@ -164,9 +245,199 @@ class CoreEngineProcManager:
} }
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank."""
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(2, "little")
self.state = CoreEngineState.NEW
@dataclass
class EngineZmqAddresses:
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]
def wait_for_engine_startup(
handshake_socket: zmq.Socket,
addresses: EngineZmqAddresses,
core_engines: list[CoreEngine],
parallel_config: ParallelConfig,
cache_config: CacheConfig,
proc_manager: Optional[CoreEngineProcManager],
coord_process: Optional[Process],
):
# Wait for engine core process(es) to send ready messages.
local_count = parallel_config.data_parallel_size_local
remote_count = len(core_engines) - local_count
# [local, remote] counts
conn_pending, start_pending = [local_count, remote_count], [0, 0]
poller = zmq.Poller()
poller.register(handshake_socket, zmq.POLLIN)
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
if coord_process is not None:
poller.register(coord_process.sentinel, zmq.POLLIN)
while any(conn_pending) or any(start_pending):
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
if any(conn_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect.", *conn_pending)
if any(start_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to start.", *start_pending)
continue
if len(events) > 1 or events[0][0] != handshake_socket:
# One of the local core processes exited.
finished = proc_manager.finished_procs() if proc_manager else {}
if coord_process is not None and coord_process.exitcode is not None:
finished[coord_process.name] = coord_process.exitcode
raise RuntimeError("Engine core initialization failed. "
"See root cause above. "
f"Failed core proc(s): {finished}")
# Receive HELLO and READY messages from the input socket.
eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
eng_index = int.from_bytes(eng_identity, "little")
engine = next((e for e in core_engines if e.identity == eng_identity),
None)
if engine is None:
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info.
init_message = msgspec.msgpack.encode(
EngineHandshakeMetadata(
addresses=addresses,
parallel_config={
"data_parallel_master_ip":
parallel_config.data_parallel_master_ip,
"data_parallel_master_port":
parallel_config.data_parallel_master_port,
"data_parallel_size":
parallel_config.data_parallel_size,
}))
handshake_socket.send_multipart((eng_identity, init_message),
copy=False)
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg["num_gpu_blocks"]
cache_config.num_gpu_blocks = num_gpu_blocks
start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
else:
raise RuntimeError(f"Unexpected {status} message for "
f"{'local' if local else 'remote'} engine "
f"{eng_index} in {engine.state} state.")
logger.debug("%s from %s core engine process %s.", status,
"local" if local else "remote", eng_index)
def wait_for_completion_or_failure(
api_server_manager: APIServerProcessManager,
local_engine_manager: Optional[CoreEngineProcManager] = None,
coordinator: Optional["DPCoordinator"] = None) -> None:
"""Wait for all processes to complete or detect if any fail.
Raises an exception if any process exits with a non-zero status.
"""
try:
logger.info("Waiting for API servers to complete ...")
# Create a mapping of sentinels to their corresponding processes
# for efficient lookup
sentinel_to_proc: dict[Any, BaseProcess] = {
proc.sentinel: proc
for proc in api_server_manager.processes
}
if coordinator:
sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc
if local_engine_manager:
for proc in local_engine_manager.processes:
sentinel_to_proc[proc.sentinel] = proc
# Check if any process terminates
while sentinel_to_proc:
# Wait for any process to terminate
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
# Process any terminated processes
for sentinel in ready_sentinels:
proc = sentinel_to_proc.pop(sentinel)
# Check if process exited with error
if proc.exitcode != 0:
raise RuntimeError(
f"Process {proc.name} (PID: {proc.pid}) "
f"died with exit code {proc.exitcode}")
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down API servers...")
except Exception as e:
logger.exception("Exception occurred while running API servers: %s",
str(e))
raise
finally:
logger.info("Terminating remaining processes ...")
api_server_manager.close()
if coordinator:
coordinator.close()
if local_engine_manager:
local_engine_manager.close()
# Note(rob): shutdown function cannot be a bound method, # Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the objedecoupct. # else the gc cannot collect the object.
def shutdown(procs: list[Process], input_address: str): def shutdown(procs: list[BaseProcess]):
# Shutdown the process. # Shutdown the process.
for proc in procs: for proc in procs:
if proc.is_alive(): if proc.is_alive():
...@@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str): ...@@ -185,12 +456,6 @@ def shutdown(procs: list[Process], input_address: str):
if proc.is_alive() and (pid := proc.pid) is not None: if proc.is_alive() and (pid := proc.pid) is not None:
kill_process_tree(pid) kill_process_tree(pid)
# Remove zmq ipc socket files.
if input_address.startswith("ipc://"):
socket_file = input_address[len("ipc://"):]
if os and os.path.exists(socket_file):
os.remove(socket_file)
def bind_kv_cache( def bind_kv_cache(
kv_caches: dict[str, torch.Tensor], kv_caches: dict[str, torch.Tensor],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment