Commit cc7f22a8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.1' into v0.9.1-ori

parents b9ea0c09 b6553be1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ctypes
import json
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.device_communicators.base_device_communicator import (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pickle
import time
......@@ -27,6 +28,43 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
logger = init_logger(__name__)
class SpinTimer:
    """Baseline wait strategy for a polling loop.

    Every call to ``spin`` hands the processor back via ``sched_yield``;
    no activity bookkeeping is kept.
    """

    def spin(self):
        """Yield the processor once."""
        sched_yield()

    def record_activity(self):
        """Hook invoked when activity is observed; intentionally a no-op."""
        return None
class SpinSleepTimer(SpinTimer):
    """Wait strategy that backs off to sleeping after a quiet period.

    With long inactivity periods it is desirable to lower system power
    consumption while vllm has nothing to do. That leaves more CPU thermal
    headroom for when a request eventually arrives, especially with several
    GPUs attached, since each GPU would otherwise pin one thread at 100%
    CPU usage.

    The approach is simply to poll less often once no activity has been
    recorded for ``busy_loop_s`` seconds.
    """

    def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
        # busy_loop_s: how long after the last activity to keep yielding.
        # wait_sleep_s: sleep duration per spin once that window has passed.
        self.busy_loop_s = busy_loop_s
        self.wait_sleep_s = wait_sleep_s
        self.last_activity = time.monotonic()

    def record_activity(self):
        """Restart the busy-loop window from the current monotonic time."""
        self.last_activity = time.monotonic()

    def spin(self):
        """Sleep if the busy-loop window has elapsed, otherwise yield."""
        now = time.monotonic()
        busy_until = self.last_activity + self.busy_loop_s
        if now >= busy_until:
            time.sleep(self.wait_sleep_s)
            return
        sched_yield()
class ShmRingBuffer:
def __init__(self,
......@@ -41,7 +79,7 @@ class ShmRingBuffer:
of items that can be stored in the buffer are known in advance.
In this case, we don't need to synchronize the access to
the buffer.
Buffer memory layout:
data metadata
| |
......@@ -237,6 +275,7 @@ class MessageQueue:
self.local_reader_rank = -1
# rank does not matter for remote readers
self._is_remote_reader = False
self._read_spin_timer = SpinTimer()
self.handle = Handle(
local_reader_ranks=local_reader_ranks,
......@@ -275,6 +314,9 @@ class MessageQueue:
self.local_socket.connect(socket_addr)
self.remote_socket = None
self._read_spin_timer = SpinSleepTimer(
) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
else:
self.buffer = None # type: ignore
self.current_idx = -1
......@@ -406,7 +448,7 @@ class MessageQueue:
# we need to wait until it is written
# Release the processor to other threads
sched_yield()
self._read_spin_timer.spin()
# if we wait for a long time, log a message
if (time.monotonic() - start_time
......@@ -437,6 +479,8 @@ class MessageQueue:
metadata_buffer[self.local_reader_rank + 1] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
self._read_spin_timer.record_activity()
break
def enqueue(self, obj, timeout: Optional[float] = None):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import queue
import threading
......@@ -27,6 +28,7 @@ class EventBatch(
):
ts: float
events: list[Any]
data_parallel_rank: Optional[int] = None
class KVCacheEvent(
......@@ -59,7 +61,22 @@ class KVEventBatch(EventBatch):
class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches."""
"""Lightweight publisher for EventBatch batches with data parallelism
support.
In data parallel setups, each DP rank runs its own EventPublisher instance
to avoid duplicate events and ensure proper event attribution:
- Each DP rank creates a separate publisher
- Publishers automatically annotate events with their data_parallel_rank
- This allows consumers to distinguish events from different DP ranks
The publisher is responsible for adding DP metadata since the scheduler
operates independently of DP topology and shouldn't need DP awareness.
"""
def __init__(self, data_parallel_rank: int = 0) -> None:
self._data_parallel_rank = data_parallel_rank
@abstractmethod
def publish(self, events: EventBatch) -> None:
......@@ -112,6 +129,7 @@ class ZmqEventPublisher(EventPublisher):
def __init__(
self,
data_parallel_rank: int,
endpoint: str = "tcp://*:5557",
replay_endpoint: Optional[str] = None,
buffer_steps: int = 10_000,
......@@ -120,6 +138,7 @@ class ZmqEventPublisher(EventPublisher):
topic: str = "",
) -> None:
# Storage
super().__init__(data_parallel_rank)
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
......@@ -127,8 +146,11 @@ class ZmqEventPublisher(EventPublisher):
self._ctx = zmq.Context.instance()
self._pub: Optional[zmq.Socket] = None
self._replay: Optional[zmq.Socket] = None
self._endpoint = endpoint
self._replay_endpoint = replay_endpoint
self._dp_rank = data_parallel_rank
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
self._replay_endpoint = self.offset_endpoint_port(
replay_endpoint, self._dp_rank)
self._hwm = hwm
self._socket_setup()
......@@ -148,6 +170,8 @@ class ZmqEventPublisher(EventPublisher):
def publish(self, events: EventBatch) -> None:
    """Hand an event batch to the publisher's internal queue.

    A batch that carries no ``data_parallel_rank`` is stamped with this
    publisher's rank before it is queued; a rank already set on the batch
    is left as is.

    Raises:
        RuntimeError: if the publisher is not running.
    """
    if not self._running:
        raise RuntimeError("Publisher is closed")
    unattributed = events.data_parallel_rank is None
    if unattributed:
        events.data_parallel_rank = self._data_parallel_rank
    self._event_queue.put(events)
def shutdown(self) -> None:
......@@ -190,11 +214,12 @@ class ZmqEventPublisher(EventPublisher):
self._pub.set_hwm(self._hwm)
# Heuristic: bind if wildcard / * present, else connect.
# bind stable, connect volatile convention
if ("*" in self._endpoint or "::" in self._endpoint
or self._endpoint.startswith("ipc://")
or self._endpoint.startswith("inproc://")):
if (self._endpoint is not None
and ("*" in self._endpoint or "::" in self._endpoint
or self._endpoint.startswith("ipc://")
or self._endpoint.startswith("inproc://"))):
self._pub.bind(self._endpoint)
else:
elif self._endpoint is not None:
self._pub.connect(self._endpoint)
# Set up replay socket: use ROUTER
......@@ -265,6 +290,38 @@ class ZmqEventPublisher(EventPublisher):
# receiving payload is (-1, b"")
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
@staticmethod
def offset_endpoint_port(endpoint: Optional[str],
data_parallel_rank: int) -> Optional[str]:
"""Helper function to offset the port in an endpoint by
the data parallel rank.
Args:
endpoint: The endpoint string
(e.g., "tcp://*:5557" or "inproc://cache")
data_parallel_rank: The data parallel rank to offset by
Returns:
The endpoint with the port offset by data_parallel_rank
or suffix appended
"""
# Do nothing if input is None or data_parallel_rank is 0
if not endpoint or data_parallel_rank == 0:
return endpoint
if "inproc" in endpoint:
return f"{endpoint}_dp{data_parallel_rank}"
if "tcp" in endpoint:
if endpoint and ":" in endpoint:
# Get everything after the last colon (the port)
last_colon_idx = endpoint.rfind(":")
base_addr = endpoint[:last_colon_idx]
base_port = int(endpoint[last_colon_idx + 1:])
new_port = base_port + data_parallel_rank
return f"{base_addr}:{new_port}"
return endpoint
raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
class EventPublisherFactory:
_registry: dict[str, Callable[..., EventPublisher]] = {
......@@ -280,7 +337,9 @@ class EventPublisherFactory:
cls._registry[name] = ctor
@classmethod
def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher:
def create(cls,
config: Optional[KVEventsConfig],
data_parallel_rank: int = 0) -> EventPublisher:
"""Create publisher from a config mapping."""
if not config:
return NullEventPublisher()
......@@ -293,4 +352,5 @@ class EventPublisherFactory:
constructor = cls._registry[kind]
except KeyError as exc:
raise ValueError(f"Unknown event publisher '{kind}'") from exc
return constructor(**config_dict)
return constructor(data_parallel_rank=data_parallel_rank,
**config_dict)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_transfer_state import (
KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from typing import TYPE_CHECKING, Callable
......@@ -70,7 +71,8 @@ class KVConnectorFactory:
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
assert issubclass(connector_cls, KVConnectorBase_V1)
logger.info("Creating v1 connector with name: %s", connector_name)
logger.info("Creating v1 connector with name: %s and engine_id: %s",
connector_name, kv_transfer_config.engine_id)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
# Scheduler connector:
# - Co-locate with scheduler process
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
LMCache KV Cache Connector for Distributed Machine Learning Inference
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
MooncakeStore Connector for Distributed Machine Learning Inference
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simple KV Cache Connector for Distributed Machine Learning Inference
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KV cache helper for store.
"""
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
logger = init_logger(__name__)
......@@ -89,3 +91,18 @@ class model_aware_kv_ops_helper:
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
def get_kv_connector_cache_layout():
    """Choose the KV cache layout implied by the configured KV connector.

    Returns:
        "HND" when the current vLLM config names "NixlConnector" as its
        KV connector and the model does not use MLA; "NHD" otherwise,
        including when the model config or the KV transfer config is
        missing.
    """
    vllm_config = get_current_vllm_config()
    model_config = vllm_config.model_config
    if model_config is None:
        logger.warning("Unable to detect current VLLM config. "
                       "Defaulting to NHD kv cache layout.")
        return "NHD"

    kv_config = vllm_config.kv_transfer_config
    # Guard against a missing KV transfer config (presumably None when no
    # connector is configured - confirm against VllmConfig); without the
    # guard the attribute access below would raise AttributeError.
    if (kv_config is not None and not model_config.use_mla
            and kv_config.kv_connector == "NixlConnector"):
        logger.info("NixlConnector detected. Setting KV cache "
                    "layout to HND for better xfer performance.")
        return "HND"
    return "NHD"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorRole)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
communication in vLLM v1
......@@ -7,9 +8,15 @@ The class provides the following primitives:
Scheduler-side: runs in the scheduler, binds metadata, which
is used by the worker-side to load/save KV cache.
get_num_new_matched_tokens() - get number of new tokens
that exist in the remote KV cache
that exist in the remote KV cache. Might be called multiple
times for a given request and should be side-effect free.
update_state_after_alloc() - update KVConnector state after
temporary buffer alloc by the CacheManager.
request_finished() - called when a request is finished, with
the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
......@@ -18,6 +25,9 @@ The class provides the following primitives:
save_kv_layer() - starts saving KV for layer i (maybe async)
wait_for_save() - blocks until all saves are done
get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
"""
import enum
......@@ -183,7 +193,8 @@ class KVConnectorBase_V1(ABC):
finished generating tokens.
Returns:
ids of requests that have finished asynchronous transfer,
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
......@@ -214,7 +225,8 @@ class KVConnectorBase_V1(ABC):
- The number of tokens that can be loaded from the
external KV cache beyond what is already computed.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps).
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
"""
pass
......@@ -224,6 +236,18 @@ class KVConnectorBase_V1(ABC):
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
If get_num_new_matched_tokens previously returned True for a
request, this function may be called twice for that same request -
first when blocks are allocated for the connector tokens to be
asynchronously loaded into, and second when any additional blocks
are allocated, after the load/transfer is complete.
Args:
request (Request): the request object.
blocks (KVCacheBlocks): the blocks allocated for the request.
num_external_tokens (int): the number of tokens that will be
loaded from the external KV cache.
"""
pass
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
import torch
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
......@@ -11,12 +12,12 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
......@@ -50,8 +51,9 @@ class MultiConnector(KVConnectorBase_V1):
self._connectors.append(
KVConnectorFactory.create_connector_v1(temp_config, role))
# A mapping from request id to the connector that is assigned to it.
self._requests_to_connector: dict[str, KVConnectorBase_V1] = {}
# A mapping from request id to the index of the connector chosen to
# load the request from (if any).
self._requests_to_connector: dict[str, int] = {}
# Keeps track of *additional* remaining async saves (beyond 1) to be
# finished per request. Not needed for async loads since we only allow
......@@ -135,25 +137,31 @@ class MultiConnector(KVConnectorBase_V1):
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
for c in self._connectors:
to_return = (0, False)
for i, c in enumerate(self._connectors):
toks, load_async = c.get_num_new_matched_tokens(
request, num_computed_tokens)
# The first connector that has new matched tokens will be assigned
# to this request.
if toks > 0:
self._requests_to_connector[request.request_id] = c
return toks, load_async
return 0, False
if to_return[0] == 0 and toks > 0:
self._requests_to_connector[request.request_id] = i
to_return = (toks, load_async)
return to_return
def update_state_after_alloc(self, request: "Request",
                             blocks: "KVCacheBlocks",
                             num_external_tokens: int):
    """Fan the post-allocation notification out to every wrapped connector.

    The connector whose index is recorded for this request id in
    ``self._requests_to_connector`` receives the real ``blocks`` and
    ``num_external_tokens``. Every other connector is called with
    ``blocks.new_empty()`` and 0 tokens. The recorded index is only read
    here, never removed.

    Args:
        request: the request the blocks were allocated for.
        blocks: the blocks allocated for the request.
        num_external_tokens: number of tokens to be loaded from the
            external KV cache by the chosen connector.
    """
    # -1 never equals an enumerate() index, so a request with no recorded
    # connector results in the "empty" call for all connectors.
    chosen_connector = self._requests_to_connector.get(
        request.request_id, -1)
    empty_blocks = blocks.new_empty()
    for i, c in enumerate(self._connectors):
        if i == chosen_connector:
            # Forward call to the chosen connector (if any).
            c.update_state_after_alloc(request, blocks,
                                       num_external_tokens)
        else:
            # Call with empty blocks for other connectors.
            c.update_state_after_alloc(request, empty_blocks, 0)
def build_connector_meta(
self,
......@@ -169,7 +177,7 @@ class MultiConnector(KVConnectorBase_V1):
def request_finished(
self,
request: "Request",
blocks: "KVCacheBlocks",
blocks: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
async_saves = 0
kv_txfer_params = None
......@@ -186,4 +194,8 @@ class MultiConnector(KVConnectorBase_V1):
kv_txfer_params = txfer_params
if async_saves > 1:
self._extra_async_saves[request.request_id] = async_saves - 1
# Clean up other state for this request.
self._requests_to_connector.pop(request.request_id, None)
return async_saves > 0, kv_txfer_params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment